Merge branch 'master' into feat/vault

pull/798/head
Mariano Cano 2 years ago committed by GitHub
commit 37b521ec6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
go: [ '1.15', '1.16', '1.17' ]
go: [ '1.17', '1.18' ]
outputs:
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
steps:
@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2
with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.44.0'
version: 'v1.45.0'
# Optional: working directory, useful for monorepos
# working-directory: somedir
@ -106,7 +106,7 @@ jobs:
name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.17
go-version: 1.18
-
name: APT Install
id: aptInstall
@ -159,7 +159,7 @@ jobs:
name: Setup Go
uses: actions/setup-go@v2
with:
go-version: '1.17'
go-version: '1.18'
-
name: Install cosign
uses: sigstore/cosign-installer@v1.1.0

@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
go: [ '1.16', '1.17' ]
go: [ '1.17', '1.18' ]
steps:
-
name: Checkout
@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2
with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.44.0'
version: 'v1.45.0'
# Optional: working directory, useful for monorepos
# working-directory: somedir
@ -58,7 +58,7 @@ jobs:
run: V=1 make ci
-
name: Codecov
if: matrix.go == '1.17'
if: matrix.go == '1.18'
uses: codecov/codecov-action@v1.2.1
with:
file: ./coverage.out # optional

1
.gitignore vendored

@ -19,3 +19,4 @@ coverage.txt
output
vendor
.idea
.envrc

@ -19,6 +19,7 @@ builds:
- linux_386
- linux_amd64
- linux_arm64
- linux_arm_5
- linux_arm_6
- linux_arm_7
- windows_amd64
@ -39,6 +40,7 @@ builds:
- linux_386
- linux_amd64
- linux_arm64
- linux_arm_5
- linux_arm_6
- linux_arm_7
- windows_amd64
@ -59,6 +61,7 @@ builds:
- linux_386
- linux_amd64
- linux_arm64
- linux_arm_5
- linux_arm_6
- linux_arm_7
- windows_amd64

@ -4,16 +4,32 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased - 0.18.2] - DATE
## [Unreleased - 0.18.3] - DATE
### Added
- Added support for renew after expiry using the claim `allowRenewAfterExpiry`.
- Added support for `extraNames` in X.509 templates.
### Changed
- IPv6 addresses are normalized as IP addresses instead of hostnames.
- More descriptive JWK decryption error message.
- Made SCEP CA URL paths dynamic
- Support two latest versions of Go (1.17, 1.18)
### Deprecated
### Removed
### Fixed
### Security
## [0.18.2] - 2022-03-01
### Added
- Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner.
- [NoSQL](https://github.com/smallstep/nosql/pull/21) package allows filtering
out database drivers using Go tags. For example, using the Go flag
`--tags=nobadger,nobbolt,nomysql` will only compile `step-ca` with the pgx
driver for PostgreSQL.
### Changed
- IPv6 addresses are normalized as IP addresses instead of hostnames.
- More descriptive JWK decryption error message.
- Make the X5C leaf certificate available to the templates using `{{ .AuthorizationCrt }}`.
### Fixed
- During provisioner add - validate provisioner configuration before storing to DB.
## [0.18.1] - 2022-02-03
### Added
- Support for ACME revocation.

@ -5,8 +5,9 @@ import (
"net/http"
"github.com/go-chi/chi"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/logging"
)
@ -70,23 +71,23 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
var nar NewAccountRequest
if err := json.Unmarshal(payload.value, &nar); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload"))
return
}
if err := nar.Validate(); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
prov, err := acmeProvisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -96,26 +97,26 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
acmeErr, ok := err.(*acme.Error)
if !ok || acmeErr.Status != http.StatusBadRequest {
// Something went wrong ...
api.WriteError(w, err)
render.Error(w, err)
return
}
// Account does not exist //
if nar.OnlyReturnExisting {
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType,
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType,
"account does not exist"))
return
}
jwk, err := jwkFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
eak, err := h.validateExternalAccountBinding(ctx, &nar)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -125,18 +126,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
Status: acme.StatusValid,
}
if err := h.db.CreateAccount(ctx, acc); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error creating account"))
render.Error(w, acme.WrapErrorISE(err, "error creating account"))
return
}
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
err := eak.BindTo(acc)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating external account binding key"))
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
return
}
acc.ExternalAccountBinding = nar.ExternalAccountBinding
@ -149,7 +150,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
h.linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
api.JSONStatus(w, acc, httpStatus)
render.JSONStatus(w, acc, httpStatus)
}
// GetOrUpdateAccount is the api for updating an ACME account.
@ -157,12 +158,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -171,12 +172,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
if !payload.isPostAsGet {
var uar UpdateAccountRequest
if err := json.Unmarshal(payload.value, &uar); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload"))
return
}
if err := uar.Validate(); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if len(uar.Status) > 0 || len(uar.Contact) > 0 {
@ -187,7 +188,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
}
if err := h.db.UpdateAccount(ctx, acc); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating account"))
render.Error(w, acme.WrapErrorISE(err, "error updating account"))
return
}
}
@ -196,7 +197,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
h.linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
api.JSON(w, acc)
render.JSON(w, acc)
}
func logOrdersByAccount(w http.ResponseWriter, oids []string) {
@ -213,22 +214,22 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
accID := chi.URLParam(r, "accID")
if acc.ID != accID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
return
}
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
h.linker.LinkOrdersByAccountID(ctx, orders)
api.JSON(w, orders)
render.JSON(w, orders)
logOrdersByAccount(w, orders)
}

@ -1,6 +1,7 @@
package api
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
@ -11,8 +12,10 @@ import (
"time"
"github.com/go-chi/chi"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
)
@ -43,6 +46,7 @@ type Handler struct {
ca acme.CertificateAuthority
linker Linker
validateChallengeOptions *acme.ValidateChallengeOptions
prerequisitesChecker func(ctx context.Context) (bool, error)
}
// HandlerOptions required to create a new ACME API request handler.
@ -60,6 +64,9 @@ type HandlerOptions struct {
// "acme" is the prefix from which the ACME api is accessed.
Prefix string
CA acme.CertificateAuthority
// PrerequisitesChecker checks if all prerequisites for serving ACME are
// met by the CA configuration.
PrerequisitesChecker func(ctx context.Context) (bool, error)
}
// NewHandler returns a new ACME API handler.
@ -76,6 +83,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
}
prerequisitesChecker := func(ctx context.Context) (bool, error) {
// by default all prerequisites are met
return true, nil
}
if ops.PrerequisitesChecker != nil {
prerequisitesChecker = ops.PrerequisitesChecker
}
return &Handler{
ca: ops.CA,
db: ops.DB,
@ -88,6 +102,7 @@ func NewHandler(ops HandlerOptions) api.RouterHandler {
return tls.DialWithDialer(dialer, network, addr, config)
},
},
prerequisitesChecker: prerequisitesChecker,
}
}
@ -95,13 +110,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler {
func (h *Handler) Route(r api.Router) {
getPath := h.linker.GetUnescapedPathSuffix
// Standard ACME API
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
validatingMiddleware := func(next nextHTTP) nextHTTP {
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))
return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))))
}
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
@ -168,11 +183,11 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
api.JSON(w, &Directory{
render.JSON(w, &Directory{
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
@ -187,7 +202,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
// NotImplemented returns a 501 and is generally a placeholder for functionality which
// MAY be added at some point in the future but is not in any way a guarantee of such.
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
}
// GetAuthorization ACME api for retrieving an Authz.
@ -195,28 +210,28 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
return
}
if acc.ID != az.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
return
}
if err = az.UpdateStatus(ctx, h.db); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status"))
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
return
}
h.linker.LinkAuthorization(ctx, az)
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
api.JSON(w, az)
render.JSON(w, az)
}
// GetChallenge ACME api for retrieving a Challenge.
@ -224,14 +239,14 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
// Just verify that the payload was set, since we're not strictly adhering
// to ACME V2 spec for reasons specified below.
_, err = payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -244,22 +259,22 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
azID := chi.URLParam(r, "authzID")
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
return
}
ch.AuthorizationID = azID
if acc.ID != ch.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own challenge '%s'", acc.ID, ch.ID))
return
}
jwk, err := jwkFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge"))
render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
return
}
@ -267,7 +282,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
api.JSON(w, ch)
render.JSON(w, ch)
}
// GetCertificate ACME api for retrieving a Certificate.
@ -275,18 +290,18 @@ func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
certID := chi.URLParam(r, "certID")
cert, err := h.db.GetCertificate(ctx, certID)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
return
}
if cert.AccountID != acc.ID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own certificate '%s'", acc.ID, certID))
return
}

@ -10,13 +10,14 @@ import (
"strings"
"github.com/go-chi/chi"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/nosql"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
)
type nextHTTP = func(http.ResponseWriter, *http.Request)
@ -64,7 +65,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
nonce, err := h.db.CreateNonce(r.Context())
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
w.Header().Set("Replay-Nonce", string(nonce))
@ -90,7 +91,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
var expected []string
p, err := provisionerFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -110,7 +111,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
return
}
}
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
render.Error(w, acme.NewError(acme.ErrorMalformedType,
"expected content-type to be in %s, but got %s", expected, ct))
}
}
@ -120,12 +121,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body"))
render.Error(w, acme.WrapErrorISE(err, "failed to read request body"))
return
}
jws, err := jose.ParseJWS(string(body))
if err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
return
}
ctx := context.WithValue(r.Context(), jwsContextKey, jws)
@ -153,15 +154,15 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
ctx := r.Context()
jws, err := jwsFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if len(jws.Signatures) == 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
return
}
if len(jws.Signatures) > 1 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
return
}
@ -172,7 +173,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
len(uh.Algorithm) > 0 ||
len(uh.Nonce) > 0 ||
len(uh.ExtraHeaders) > 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
return
}
hdr := sig.Protected
@ -182,13 +183,13 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
switch k := hdr.JSONWebKey.Key.(type) {
case *rsa.PublicKey:
if k.Size() < keyutil.MinRSAKeyBytes {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
render.Error(w, acme.NewError(acme.ErrorMalformedType,
"rsa keys must be at least %d bits (%d bytes) in size",
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
return
}
default:
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
render.Error(w, acme.NewError(acme.ErrorMalformedType,
"jws key type and algorithm do not match"))
return
}
@ -196,35 +197,35 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
// we good
default:
api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
return
}
// Check the validity/freshness of the Nonce.
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
// Check that the JWS url matches the requested url.
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
if !ok {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
return
}
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
if jwsURL != reqURL.String() {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
render.Error(w, acme.NewError(acme.ErrorMalformedType,
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
return
}
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
return
}
if hdr.JSONWebKey == nil && hdr.KeyID == "" {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
return
}
next(w, r)
@ -239,23 +240,23 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
ctx := r.Context()
jws, err := jwsFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
jwk := jws.Signatures[0].Protected.JSONWebKey
if jwk == nil {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
return
}
if !jwk.Valid() {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
return
}
// Overwrite KeyID with the JWK thumbprint.
jwk.KeyID, err = acme.KeyToID(jwk)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
return
}
@ -269,11 +270,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
// For NewAccount and Revoke requests ...
break
case err != nil:
api.WriteError(w, err)
render.Error(w, err)
return
default:
if !acc.IsValid() {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return
}
ctx = context.WithValue(ctx, accContextKey, acc)
@ -283,25 +284,24 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
}
// lookupProvisioner loads the provisioner associated with the request.
// Responsds 404 if the provisioner does not exist.
// Responds 404 if the provisioner does not exist.
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return
}
p, err := h.ca.LoadProvisionerByName(name)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
acmeProv, ok := p.(*provisioner.ACME)
if !ok {
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return
}
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
@ -309,6 +309,24 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
}
}
// checkPrerequisites checks if all prerequisites for serving ACME
// are met by the CA configuration.
func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ok, err := h.prerequisitesChecker(ctx)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
return
}
if !ok {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return
}
next(w, r.WithContext(ctx))
}
}
// lookupJWK loads the JWK associated with the acme account referenced by the
// kid parameter of the signed payload.
// Make sure to parse and validate the JWS before running this middleware.
@ -317,14 +335,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
render.Error(w, acme.NewError(acme.ErrorMalformedType,
"kid does not have required prefix; expected %s, but got %s",
kidPrefix, kid))
return
@ -334,14 +352,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
acc, err := h.db.GetAccount(ctx, accID)
switch {
case nosql.IsErrNotFound(err):
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
return
case err != nil:
api.WriteError(w, err)
render.Error(w, err)
return
default:
if !acc.IsValid() {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return
}
ctx = context.WithValue(ctx, accContextKey, acc)
@ -359,7 +377,7 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -395,21 +413,21 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
jwk, err := jwkFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
return
}
payload, err := jws.Verify(jwk)
if err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
return
}
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
@ -426,11 +444,11 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
payload, err := payloadFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if !payload.isPostAsGet {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
return
}
next(w, r)

@ -1656,3 +1656,91 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
})
}
}
func TestHandler_checkPrerequisites(t *testing.T) {
prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
u := fmt.Sprintf("%s/acme/%s/account/1234",
baseURL, provName)
type test struct {
linker Linker
ctx context.Context
prerequisitesChecker func(context.Context) (bool, error)
next func(http.ResponseWriter, *http.Request)
err *acme.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"fail/error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
linker: NewLinker("dns", "acme"),
ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
},
err: acme.WrapErrorISE(errors.New("force"), "error checking acme provisioner prerequisites"),
statusCode: 500,
}
},
"fail/prerequisites-nok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
linker: NewLinker("dns", "acme"),
ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
},
err: acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"),
statusCode: 501,
}
},
"ok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
linker: NewLinker("dns", "acme"),
ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
},
statusCode: 200,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.checkPrerequisites(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := io.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, bytes.TrimSpace(body), testBody)
}
})
}
}

@ -11,9 +11,11 @@ import (
"time"
"github.com/go-chi/chi"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"go.step.sm/crypto/randutil"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/render"
)
// NewOrderRequest represents the body for a NewOrder request.
@ -70,28 +72,28 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
var nor NewOrderRequest
if err := json.Unmarshal(payload.value, &nor); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-order request payload"))
return
}
if err := nor.Validate(); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -116,7 +118,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
Status: acme.StatusPending,
}
if err := h.newAuthorization(ctx, az); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
o.AuthorizationIDs[i] = az.ID
@ -135,14 +137,14 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
}
if err := h.db.CreateOrder(ctx, o); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error creating order"))
render.Error(w, acme.WrapErrorISE(err, "error creating order"))
return
}
h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSONStatus(w, o, http.StatusCreated)
render.JSONStatus(w, o, http.StatusCreated)
}
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
@ -186,38 +188,38 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return
}
if acc.ID != o.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID))
return
}
if prov.GetID() != o.ProvisionerID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return
}
if err = o.UpdateStatus(ctx, h.db); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating order status"))
render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
return
}
h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSON(w, o)
render.JSON(w, o)
}
// FinalizeOrder attemptst to finalize an order and create a certificate.
@ -225,54 +227,54 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
var fr FinalizeRequest
if err := json.Unmarshal(payload.value, &fr); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal finalize-order request payload"))
return
}
if err := fr.Validate(); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return
}
if acc.ID != o.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID))
return
}
if prov.GetID() != o.ProvisionerID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return
}
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order"))
render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
return
}
h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSON(w, o)
render.JSON(w, o)
}
// challengeTypes determines the types of challenges that should be used

@ -10,13 +10,14 @@ import (
"net/http"
"strings"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ocsp"
)
type revokePayload struct {
@ -30,65 +31,65 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
var p revokePayload
err = json.Unmarshal(payload.value, &p)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error unmarshaling payload"))
render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload"))
return
}
certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate)
if err != nil {
// in this case the most likely cause is a client that didn't properly encode the certificate
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
return
}
certToBeRevoked, err := x509.ParseCertificate(certBytes)
if err != nil {
// in this case a client may have encoded something different than a certificate
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
return
}
serial := certToBeRevoked.SerialNumber.String()
dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
return
}
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
// this should never happen
api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal"))
render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal"))
return
}
if shouldCheckAccountFrom(jws) {
account, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
if acmeErr != nil {
api.WriteError(w, acmeErr)
render.Error(w, acmeErr)
return
}
} else {
@ -97,26 +98,26 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
_, err := jws.Verify(certToBeRevoked.PublicKey)
if err != nil {
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
return
}
}
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
return
}
if hasBeenRevokedBefore {
api.WriteError(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
return
}
reasonCode := p.ReasonCode
acmeErr := validateReasonCode(reasonCode)
if acmeErr != nil {
api.WriteError(w, acmeErr)
render.Error(w, acmeErr)
return
}
@ -124,14 +125,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod)
err = prov.AuthorizeRevoke(ctx, "")
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
return
}
options := revokeOptions(serial, certToBeRevoked, reasonCode)
err = h.ca.Revoke(ctx, options)
if err != nil {
api.WriteError(w, wrapRevokeErr(err))
render.Error(w, wrapRevokeErr(err))
return
}

@ -79,7 +79,7 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
}
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
resp, err := vo.HTTPGet(u.String())
if err != nil {
@ -119,6 +119,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
return nil
}
// http01ChallengeHost checks if a Challenge value is an IPv6 address
// and adds square brackets if that's the case, so that it can be used
// as a hostname. Returns the original Challenge value as the host to
// use in other cases.
func http01ChallengeHost(value string) string {
if ip := net.ParseIP(value); ip != nil && ip.To4() == nil {
value = "[" + value + "]"
}
return value
}
func tlsAlert(err error) uint8 {
var opErr *net.OpError
if errors.As(err, &opErr) {

@ -13,6 +13,7 @@ import (
"encoding/asn1"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"math/big"
@ -23,9 +24,9 @@ import (
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
)
func Test_storeError(t *testing.T) {
@ -2350,3 +2351,34 @@ func Test_serverName(t *testing.T) {
})
}
}
func Test_http01ChallengeHost(t *testing.T) {
tests := []struct {
name string
value string
want string
}{
{
name: "dns",
value: "www.example.com",
want: "www.example.com",
},
{
name: "ipv4",
value: "127.0.0.1",
want: "127.0.0.1",
},
{
name: "ipv6",
value: "::1",
want: "[::1]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := http01ChallengeHost(tt.value); got != tt.want {
t.Errorf("http01ChallengeHost() = %v, want %v", got, tt.want)
}
})
}
}

@ -3,13 +3,10 @@ package acme
import (
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/api/render"
)
// ProblemType is the type of the ACME problem.
@ -353,26 +350,8 @@ func (e *Error) ToLog() (interface{}, error) {
return string(b), nil
}
// WriteError writes to w a JSON representation of the given error.
func WriteError(w http.ResponseWriter, err *Error) {
// Render implements render.RenderableError for Error.
func (e *Error) Render(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/problem+json")
w.WriteHeader(err.StatusCode())
// Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err.Err,
})
if os.Getenv("STEPDEBUG") == "1" {
if e, ok := err.Err.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
}
}
}
if err := json.NewEncoder(w).Encode(err); err != nil {
log.Println(err)
}
render.JSONStatus(w, e, e.StatusCode())
}

@ -20,6 +20,9 @@ import (
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/certificates/api/log"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
@ -33,6 +36,7 @@ type Authority interface {
// context specifies the Authorize[Sign|Revoke|etc.] method.
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
@ -43,7 +47,7 @@ type Authority interface {
GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
Revoke(context.Context, *authority.RevokeOptions) error
GetEncryptedKey(kid string) (string, error)
GetRoots() (federation []*x509.Certificate, err error)
GetRoots() ([]*x509.Certificate, error)
GetFederation() ([]*x509.Certificate, error)
Version() authority.Version
}
@ -257,6 +261,7 @@ func (h *caHandler) Route(r Router) {
r.MethodFunc("GET", "/provisioners", h.Provisioners)
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
r.MethodFunc("GET", "/roots", h.Roots)
r.MethodFunc("GET", "/roots.pem", h.RootsPEM)
r.MethodFunc("GET", "/federation", h.Federation)
// SSH CA
r.MethodFunc("POST", "/ssh/sign", h.SSHSign)
@ -280,7 +285,7 @@ func (h *caHandler) Route(r Router) {
// Version is an HTTP handler that returns the version of the server.
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
v := h.Authority.Version()
JSON(w, VersionResponse{
render.JSON(w, VersionResponse{
Version: v.Version,
RequireClientAuthentication: v.RequireClientAuthentication,
})
@ -288,7 +293,7 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
// Health is an HTTP handler that returns the status of the server.
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
JSON(w, HealthResponse{Status: "ok"})
render.JSON(w, HealthResponse{Status: "ok"})
}
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
@ -299,11 +304,11 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
// Load root certificate with the
cert, err := h.Authority.Root(sum)
if err != nil {
WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
return
}
JSON(w, &RootResponse{RootPEM: Certificate{cert}})
render.JSON(w, &RootResponse{RootPEM: Certificate{cert}})
}
func certChainToPEM(certChain []*x509.Certificate) []Certificate {
@ -318,16 +323,16 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := ParseCursor(r)
if err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
p, next, err := h.Authority.GetProvisioners(cursor, limit)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
JSON(w, &ProvisionersResponse{
render.JSON(w, &ProvisionersResponse{
Provisioners: p,
NextCursor: next,
})
@ -338,17 +343,17 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
kid := chi.URLParam(r, "kid")
key, err := h.Authority.GetEncryptedKey(kid)
if err != nil {
WriteError(w, errs.NotFoundErr(err))
render.Error(w, errs.NotFoundErr(err))
return
}
JSON(w, &ProvisionerKeyResponse{key})
render.JSON(w, &ProvisionerKeyResponse{key})
}
// Roots returns all the root certificates for the CA.
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots()
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error getting roots"))
render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
return
}
@ -357,16 +362,39 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
certs[i] = Certificate{roots[i]}
}
JSONStatus(w, &RootsResponse{
render.JSONStatus(w, &RootsResponse{
Certificates: certs,
}, http.StatusCreated)
}
// RootsPEM returns all the root certificates for the CA in PEM format.
func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots()
if err != nil {
render.Error(w, errs.InternalServerErr(err))
return
}
w.Header().Set("Content-Type", "application/x-pem-file")
for _, root := range roots {
block := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: root.Raw,
})
if _, err := w.Write(block); err != nil {
log.Error(w, err)
return
}
}
}
// Federation returns all the public certificates in the federation.
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation()
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error getting federated roots"))
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
return
}
@ -375,7 +403,7 @@ func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
certs[i] = Certificate{federated[i]}
}
JSONStatus(w, &FederationResponse{
render.JSONStatus(w, &FederationResponse{
Certificates: certs,
}, http.StatusCreated)
}

@ -13,6 +13,7 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
@ -27,14 +28,17 @@ import (
"github.com/go-chi/chi"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/x509util"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/templates"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ssh"
)
const (
@ -171,6 +175,7 @@ type mockAuthority struct {
ret1, ret2 interface{}
err error
authorizeSign func(ott string) ([]provisioner.SignOption, error)
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
@ -208,6 +213,13 @@ func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, err
return m.ret1.([]provisioner.SignOption), m.err
}
func (m *mockAuthority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
if m.authorizeRenewToken != nil {
return m.authorizeRenewToken(ctx, ott)
}
return m.ret1.(*x509.Certificate), m.err
}
func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions {
if m.getTLSOptions != nil {
return m.getTLSOptions()
@ -920,48 +932,141 @@ func Test_caHandler_Renew(t *testing.T) {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
}
// Prepare root and leaf for renew after expiry test.
now := time.Now()
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
leafPub, leafPriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
root := &x509.Certificate{
Subject: pkix.Name{CommonName: "Test Root CA"},
PublicKey: rootPub,
KeyUsage: x509.KeyUsageCertSign,
BasicConstraintsValid: true,
IsCA: true,
NotBefore: now.Add(-2 * time.Hour),
NotAfter: now.Add(time.Hour),
}
root, err = x509util.CreateCertificate(root, root, rootPub, rootPriv)
if err != nil {
t.Fatal(err)
}
expiredLeaf := &x509.Certificate{
Subject: pkix.Name{CommonName: "Leaf certificate"},
PublicKey: leafPub,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
EmailAddresses: []string{"test@example.org"},
}
expiredLeaf, err = x509util.CreateCertificate(expiredLeaf, root, leafPub, rootPriv)
if err != nil {
t.Fatal(err)
}
// Generate renew after expiry token
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("x5cInsecure", []string{base64.StdEncoding.EncodeToString(expiredLeaf.Raw)})
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: leafPriv}, so)
if err != nil {
t.Fatal(err)
}
generateX5cToken := func(claims jose.Claims) string {
s, err := jose.Signed(sig).Claims(claims).CompactSerialize()
if err != nil {
t.Fatal(err)
}
return s
}
tests := []struct {
name string
tls *tls.ConnectionState
header http.Header
cert *x509.Certificate
root *x509.Certificate
err error
statusCode int
}{
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
{"ok", cs, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"ok renew after expiry", &tls.ConnectionState{}, http.Header{
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
})},
}, expiredLeaf, root, nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, nil, http.StatusBadRequest},
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, nil, http.StatusBadRequest},
{"renew error", cs, nil, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
{"fail expired token", &tls.ConnectionState{}, http.Header{
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
})},
}, expiredLeaf, root, errs.Forbidden("an error"), http.StatusUnauthorized},
{"fail invalid root", &tls.ConnectionState{}, http.Header{
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
})},
}, expiredLeaf, parseCertificate(rootPEM), errs.Forbidden("an error"), http.StatusUnauthorized},
}
expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err,
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
if err != nil {
return nil, errs.Unauthorized(err.Error())
}
var claims jose.Claims
if err := jwt.Claims(chain[0][0].PublicKey, &claims); err != nil {
return nil, errs.Unauthorized(err.Error())
}
if err := claims.ValidateWithLeeway(jose.Expected{
Time: now,
}, time.Minute); err != nil {
return nil, errs.Unauthorized(err.Error())
}
return chain[0][0], nil
},
getTLSOptions: func() *authority.TLSOptions {
return nil
},
}).(*caHandler)
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
req.TLS = tt.tls
req.Header = tt.header
w := httptest.NewRecorder()
h.Renew(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
res := w.Result()
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Renew unexpected error = %v", err)
}
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
t.Errorf("%s", body)
}
if tt.statusCode < http.StatusBadRequest {
expected := []byte(`{"crt":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `",` +
`"ca":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `",` +
`"certChain":["` +
strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `","` +
strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `"]}`)
if !bytes.Equal(bytes.TrimSpace(body), expected) {
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)
t.Errorf("caHandler.Root Body = \n%s, wants \n%s", body, expected)
}
}
})
@ -1239,6 +1344,46 @@ func Test_caHandler_Roots(t *testing.T) {
}
}
func Test_caHandler_RootsPEM(t *testing.T) {
parsedRoot := parseCertificate(rootPEM)
tests := []struct {
name string
roots []*x509.Certificate
err error
statusCode int
expect string
}{
{"one root", []*x509.Certificate{parsedRoot}, nil, http.StatusOK, rootPEM},
{"two roots", []*x509.Certificate{parsedRoot, parsedRoot}, nil, http.StatusOK, rootPEM + "\n" + rootPEM},
{"fail", nil, errors.New("an error"), http.StatusInternalServerError, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler)
req := httptest.NewRequest("GET", "https://example.com/roots", nil)
w := httptest.NewRecorder()
h.RootsPEM(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.RootsPEM StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
body, err := io.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.RootsPEM unexpected error = %v", err)
}
if tt.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), []byte(tt.expect)) {
t.Errorf("caHandler.RootsPEM Body = %s, wants %s", body, tt.expect)
}
}
})
}
}
func Test_caHandler_Federation(t *testing.T) {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},

@ -1,64 +0,0 @@
package api
import (
"encoding/json"
"fmt"
"net/http"
"os"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/scep"
)
// WriteError writes to w a JSON representation of the given error.
func WriteError(w http.ResponseWriter, err error) {
switch k := err.(type) {
case *acme.Error:
acme.WriteError(w, k)
return
case *admin.Error:
admin.WriteError(w, k)
return
case *scep.Error:
w.Header().Set("Content-Type", "text/plain")
default:
w.Header().Set("Content-Type", "application/json")
}
cause := errors.Cause(err)
if sc, ok := err.(errs.StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
if sc, ok := cause.(errs.StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
w.WriteHeader(http.StatusInternalServerError)
}
}
// Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err,
})
if os.Getenv("STEPDEBUG") == "1" {
if e, ok := err.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
} else if e, ok := cause.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
}
}
}
if err := json.NewEncoder(w).Encode(err); err != nil {
LogError(w, err)
}
}

@ -0,0 +1,79 @@
// Package log implements API-related logging helpers.
package log
import (
"fmt"
"log"
"net/http"
"os"
"github.com/pkg/errors"
"github.com/smallstep/certificates/logging"
)
// StackTracedError is the set of errors implementing the StackTrace function.
//
// Errors implementing this interface have their stack traces logged when passed
// to the Error function of this package.
type StackTracedError interface {
error
StackTrace() errors.StackTrace
}
// Error adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package.
func Error(rw http.ResponseWriter, err error) {
rl, ok := rw.(logging.ResponseLogger)
if !ok {
log.Println(err)
return
}
rl.WithFields(map[string]interface{}{
"error": err,
})
if os.Getenv("STEPDEBUG") != "1" {
return
}
e, ok := err.(StackTracedError)
if !ok {
e, ok = errors.Cause(err).(StackTracedError)
}
if ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e.StackTrace()),
})
}
}
// EnabledResponse log the response object if it implements the EnableLogger
// interface.
func EnabledResponse(rw http.ResponseWriter, v interface{}) {
type enableLogger interface {
ToLog() (interface{}, error)
}
if el, ok := v.(enableLogger); ok {
out, err := el.ToLog()
if err != nil {
Error(rw, err)
return
}
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"response": out,
})
} else {
log.Println(out)
}
}
}

@ -0,0 +1,44 @@
package log
import (
"errors"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/smallstep/certificates/logging"
)
func TestError(t *testing.T) {
theError := errors.New("the error")
type args struct {
rw http.ResponseWriter
err error
}
tests := []struct {
name string
args args
withFields bool
}{
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Error(tt.args.rw, tt.args.err)
if tt.withFields {
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
fields := rl.Fields()
if !reflect.DeepEqual(fields["error"], theError) {
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
}
} else {
t.Error("ResponseWriter does not implement logging.ResponseLogger")
}
}
})
}
}

@ -0,0 +1,31 @@
// Package read implements request object readers.
package read
import (
"encoding/json"
"io"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/errs"
)
// JSON reads JSON from the request body and stores it in the value
// pointed by v.
func JSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequestErr(err, "error decoding json")
}
return nil
}
// ProtoJSON reads JSON from the request body and stores it in the value
// pointed by v.
func ProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r)
if err != nil {
return errs.BadRequestErr(err, "error reading request body")
}
return protojson.Unmarshal(data, m)
}

@ -0,0 +1,46 @@
package read
import (
"io"
"reflect"
"strings"
"testing"
"github.com/smallstep/certificates/errs"
)
func TestJSON(t *testing.T) {
type args struct {
r io.Reader
v interface{}
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := JSON(tt.args.r, &tt.args.v)
if (err != nil) != tt.wantErr {
t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
e, ok := err.(*errs.Error)
if ok {
if code := e.StatusCode(); code != 400 {
t.Errorf("error.StatusCode() = %v, wants 400", code)
}
} else {
t.Errorf("error type = %T, wants *Error", err)
}
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
}
})
}
}

@ -3,6 +3,8 @@ package api
import (
"net/http"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs"
)
@ -27,24 +29,24 @@ func (s *RekeyRequest) Validate() error {
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing client certificate"))
render.Error(w, errs.BadRequest("missing client certificate"))
return
}
var body RekeyRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
if err != nil {
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
return
}
certChainPEM := certChainToPEM(certChain)
@ -54,7 +56,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
}
LogCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
render.JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,

@ -0,0 +1,122 @@
// Package render implements functionality related to response rendering.
package render
import (
"bytes"
"encoding/json"
"net/http"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/api/log"
)
// JSON is shorthand for JSONStatus(w, v, http.StatusOK).
func JSON(w http.ResponseWriter, v interface{}) {
JSONStatus(w, v, http.StatusOK)
}
// JSONStatus marshals v into w. It additionally sets the status code of
// w to the given one.
//
// JSONStatus sets the Content-Type of w to application/json unless one is
// specified.
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(v); err != nil {
panic(err)
}
setContentTypeUnlessPresent(w, "application/json")
w.WriteHeader(status)
_, _ = b.WriteTo(w)
log.EnabledResponse(w, v)
}
// ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK).
func ProtoJSON(w http.ResponseWriter, m proto.Message) {
ProtoJSONStatus(w, m, http.StatusOK)
}
// ProtoJSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
b, err := protojson.Marshal(m)
if err != nil {
panic(err)
}
setContentTypeUnlessPresent(w, "application/json")
w.WriteHeader(status)
_, _ = w.Write(b)
}
func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) {
const header = "Content-Type"
h := w.Header()
if _, ok := h[header]; !ok {
h.Set(header, contentType)
}
}
// RenderableError is the set of errors that implement the basic Render method.
//
// Errors that implement this interface will use their own Render method when
// being rendered into responses.
type RenderableError interface {
error
Render(http.ResponseWriter)
}
// Error marshals the JSON representation of err to w. In case err implements
// RenderableError its own Render method will be called instead.
func Error(w http.ResponseWriter, err error) {
log.Error(w, err)
if e, ok := err.(RenderableError); ok {
e.Render(w)
return
}
JSONStatus(w, err, statusCodeFromError(err))
}
// StatusCodedError is the set of errors that implement the basic StatusCode
// function.
//
// Errors that implement this interface will use the code reported by StatusCode
// as the HTTP response code when being rendered by this package.
type StatusCodedError interface {
error
StatusCode() int
}
func statusCodeFromError(err error) (code int) {
code = http.StatusInternalServerError
type causer interface {
Cause() error
}
for err != nil {
if sc, ok := err.(StatusCodedError); ok {
code = sc.StatusCode()
break
}
cause, ok := err.(causer)
if !ok {
break
}
err = cause.Cause()
}
return
}

@ -0,0 +1,115 @@
package render
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/smallstep/certificates/logging"
)
func TestJSON(t *testing.T) {
rec := httptest.NewRecorder()
rw := logging.NewResponseLogger(rec)
JSON(rw, map[string]interface{}{"foo": "bar"})
assert.Equal(t, http.StatusOK, rec.Result().StatusCode)
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
assert.Equal(t, "{\"foo\":\"bar\"}\n", rec.Body.String())
assert.Empty(t, rw.Fields())
}
func TestJSONPanics(t *testing.T) {
assert.Panics(t, func() {
JSON(httptest.NewRecorder(), make(chan struct{}))
})
}
type renderableError struct {
Code int `json:"-"`
Message string `json:"message"`
}
func (err renderableError) Error() string {
return err.Message
}
func (err renderableError) Render(w http.ResponseWriter) {
w.Header().Set("Content-Type", "something/custom")
JSONStatus(w, err, err.Code)
}
type statusedError struct {
Contents string
}
func (err statusedError) Error() string { return err.Contents }
func (statusedError) StatusCode() int { return 432 }
func TestError(t *testing.T) {
cases := []struct {
err error
code int
body string
header string
}{
0: {
err: renderableError{532, "some string"},
code: 532,
body: "{\"message\":\"some string\"}\n",
header: "something/custom",
},
1: {
err: statusedError{"123"},
code: 432,
body: "{\"Contents\":\"123\"}\n",
header: "application/json",
},
}
for caseIndex := range cases {
kase := cases[caseIndex]
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
rec := httptest.NewRecorder()
Error(rec, kase.err)
assert.Equal(t, kase.code, rec.Result().StatusCode)
assert.Equal(t, kase.body, rec.Body.String())
assert.Equal(t, kase.header, rec.Header().Get("Content-Type"))
})
}
}
type causedError struct {
cause error
}
func (err causedError) Error() string { return fmt.Sprintf("cause: %s", err.cause) }
func (err causedError) Cause() error { return err.cause }
func TestStatusCodeFromError(t *testing.T) {
cases := []struct {
err error
exp int
}{
0: {nil, http.StatusInternalServerError},
1: {io.EOF, http.StatusInternalServerError},
2: {statusedError{"123"}, 432},
3: {causedError{statusedError{"432"}}, 432},
}
for caseIndex, kase := range cases {
assert.Equal(t, kase.exp, statusCodeFromError(kase.err), "case: %d", caseIndex)
}
}

@ -1,22 +1,31 @@
package api
import (
"crypto/x509"
"net/http"
"strings"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs"
)
const (
authorizationHeader = "Authorization"
bearerScheme = "Bearer"
)
// Renew uses the information of certificate in the TLS connection to create a
// new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing client certificate"))
cert, err := h.getPeerCertificate(r)
if err != nil {
render.Error(w, err)
return
}
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
certChain, err := h.Authority.Renew(cert)
if err != nil {
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return
}
certChainPEM := certChainToPEM(certChain)
@ -26,10 +35,22 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
}
LogCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
render.JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(),
}, http.StatusCreated)
}
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
return r.TLS.PeerCertificates[0], nil
}
if s := r.Header.Get(authorizationHeader); s != "" {
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1])
}
}
return nil, errs.BadRequest("missing client certificate")
}

@ -4,11 +4,14 @@ import (
"context"
"net/http"
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"golang.org/x/crypto/ocsp"
)
// RevokeResponse is the response object that returns the health of the server.
@ -48,13 +51,13 @@ func (r *RevokeRequest) Validate() (err error) {
// TODO: Add CRL and OCSP support.
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
@ -71,7 +74,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
if len(body.OTT) > 0 {
logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
opts.OTT = body.OTT
@ -80,12 +83,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
// the client certificate Serial Number must match the serial number
// being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing ott or client certificate"))
render.Error(w, errs.BadRequest("missing ott or client certificate"))
return
}
opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial {
WriteError(w, errs.BadRequest("serial number in client certificate different than body"))
render.Error(w, errs.BadRequest("serial number in client certificate different than body"))
return
}
// TODO: should probably be checking if the certificate was revoked here.
@ -96,12 +99,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
}
if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.ForbiddenErr(err, "error revoking certificate"))
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
return
}
logRevoke(w, opts)
JSON(w, &RevokeResponse{Status: "ok"})
render.JSON(w, &RevokeResponse{Status: "ok"})
}
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {

@ -13,6 +13,7 @@ import (
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"

@ -5,6 +5,8 @@ import (
"encoding/json"
"net/http"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
@ -49,14 +51,14 @@ type SignResponse struct {
// information in the certificate request.
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
@ -68,13 +70,13 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
if err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing certificate"))
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
return
}
certChainPEM := certChainToPEM(certChain)
@ -83,7 +85,7 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
caPEM = certChainPEM[1]
}
LogCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
render.JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,

@ -9,12 +9,15 @@ import (
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/templates"
"golang.org/x/crypto/ssh"
)
// SSHAuthority is the interface implemented by a SSH CA authority.
@ -249,20 +252,20 @@ type SSHBastionResponse struct {
// the request.
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil {
WriteError(w, errs.BadRequestErr(err, "error parsing publicKey"))
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
return
}
@ -270,7 +273,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil {
WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
return
}
}
@ -287,13 +290,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return
}
@ -301,7 +304,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return
}
addUserCertificate = &SSHCertificate{addUserCert}
@ -314,7 +317,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
@ -326,13 +329,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing identity certificate"))
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
return
}
identityCertificate = certChainToPEM(certChain)
}
JSONStatus(w, &SSHSignResponse{
render.JSONStatus(w, &SSHSignResponse{
Certificate: SSHCertificate{cert},
AddUserCertificate: addUserCertificate,
IdentityCertificate: identityCertificate,
@ -344,12 +347,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots(r.Context())
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound("no keys found"))
render.Error(w, errs.NotFound("no keys found"))
return
}
@ -361,7 +364,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
}
JSON(w, resp)
render.JSON(w, resp)
}
// SSHFederation is an HTTP handler that returns the federated SSH public keys
@ -369,12 +372,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation(r.Context())
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound("no keys found"))
render.Error(w, errs.NotFound("no keys found"))
return
}
@ -386,25 +389,25 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
}
JSON(w, resp)
render.JSON(w, resp)
}
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients
// and servers.
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
@ -415,31 +418,31 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
case provisioner.SSHHostCert:
cfg.HostTemplates = ts
default:
WriteError(w, errs.InternalServer("it should hot get here"))
render.Error(w, errs.InternalServer("it should hot get here"))
return
}
JSON(w, cfg)
render.JSON(w, cfg)
}
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
JSON(w, &SSHCheckPrincipalResponse{
render.JSON(w, &SSHCheckPrincipalResponse{
Exists: exists,
})
}
@ -453,10 +456,10 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
JSON(w, &SSHGetHostsResponse{
render.JSON(w, &SSHGetHostsResponse{
Hosts: hosts,
})
}
@ -464,22 +467,22 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
// SSHBastion provides returns the bastion configured if any.
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
JSON(w, &SSHBastionResponse{
render.JSON(w, &SSHBastionResponse{
Hostname: body.Hostname,
Bastion: bastion,
})

@ -4,9 +4,12 @@ import (
"net/http"
"time"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
)
// SSHRekeyRequest is the request body of an SSH certificate request.
@ -38,37 +41,38 @@ type SSHRekeyResponse struct {
// the request.
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil {
WriteError(w, errs.BadRequestErr(err, "error parsing publicKey"))
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
return
}
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
return
}
@ -78,11 +82,11 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return
}
JSONStatus(w, &SSHRekeyResponse{
render.JSONStatus(w, &SSHRekeyResponse{
Certificate: SSHCertificate{newCert},
IdentityCertificate: identity,
}, http.StatusCreated)

@ -6,6 +6,9 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
)
@ -36,31 +39,32 @@ type SSHRenewResponse struct {
// the request.
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
_, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
newCert, err := h.Authority.RenewSSH(ctx, oldCert)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
return
}
@ -70,11 +74,11 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return
}
JSONStatus(w, &SSHSignResponse{
render.JSONStatus(w, &SSHSignResponse{
Certificate: SSHCertificate{newCert},
IdentityCertificate: identity,
}, http.StatusCreated)

@ -3,11 +3,14 @@ package api
import (
"net/http"
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"golang.org/x/crypto/ocsp"
)
// SSHRevokeResponse is the response object that returns the health of the server.
@ -47,13 +50,13 @@ func (r *SSHRevokeRequest) Validate() (err error) {
// NOTE: currently only Passive revocation is supported.
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
@ -69,18 +72,18 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
// otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
opts.OTT = body.OTT
if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
return
}
logSSHRevoke(w, opts)
JSON(w, &SSHRevokeResponse{Status: "ok"})
render.JSON(w, &SSHRevokeResponse{Status: "ok"})
}
func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {

@ -18,12 +18,13 @@ import (
"testing"
"time"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/templates"
"golang.org/x/crypto/ssh"
)
var (

@ -1,109 +0,0 @@
package api
import (
"encoding/json"
"io"
"log"
"net/http"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)
// EnableLogger is an interface that enables response logging for an object.
type EnableLogger interface {
ToLog() (interface{}, error)
}
// LogError adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package.
func LogError(rw http.ResponseWriter, err error) {
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err,
})
} else {
log.Println(err)
}
}
// LogEnabledResponse log the response object if it implements the EnableLogger
// interface.
func LogEnabledResponse(rw http.ResponseWriter, v interface{}) {
if el, ok := v.(EnableLogger); ok {
out, err := el.ToLog()
if err != nil {
LogError(rw, err)
return
}
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"response": out,
})
} else {
log.Println(out)
}
}
}
// JSON writes the passed value into the http.ResponseWriter.
func JSON(w http.ResponseWriter, v interface{}) {
JSONStatus(w, v, http.StatusOK)
}
// JSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(v); err != nil {
LogError(w, err)
return
}
LogEnabledResponse(w, v)
}
// ProtoJSON writes the passed value into the http.ResponseWriter.
func ProtoJSON(w http.ResponseWriter, m proto.Message) {
ProtoJSONStatus(w, m, http.StatusOK)
}
// ProtoJSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
b, err := protojson.Marshal(m)
if err != nil {
LogError(w, err)
return
}
if _, err := w.Write(b); err != nil {
LogError(w, err)
return
}
//LogEnabledResponse(w, v)
}
// ReadJSON reads JSON from the request body and stores it in the value
// pointed by v.
func ReadJSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequestErr(err, "error decoding json")
}
return nil
}
// ReadProtoJSON reads JSON from the request body and stores it in the value
// pointed by v.
func ReadProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r)
if err != nil {
return errs.BadRequestErr(err, "error reading request body")
}
return protojson.Unmarshal(data, m)
}

@ -1,125 +0,0 @@
package api
import (
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
)
func TestLogError(t *testing.T) {
theError := errors.New("the error")
type args struct {
rw http.ResponseWriter
err error
}
tests := []struct {
name string
args args
withFields bool
}{
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
LogError(tt.args.rw, tt.args.err)
if tt.withFields {
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
fields := rl.Fields()
if !reflect.DeepEqual(fields["error"], theError) {
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
}
} else {
t.Error("ResponseWriter does not implement logging.ResponseLogger")
}
}
})
}
}
func TestJSON(t *testing.T) {
type args struct {
rw http.ResponseWriter
v interface{}
}
tests := []struct {
name string
args args
ok bool
}{
{"ok", args{httptest.NewRecorder(), map[string]interface{}{"foo": "bar"}}, true},
{"fail", args{httptest.NewRecorder(), make(chan int)}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rw := logging.NewResponseLogger(tt.args.rw)
JSON(rw, tt.args.v)
rr, ok := tt.args.rw.(*httptest.ResponseRecorder)
if !ok {
t.Error("ResponseWriter does not implement *httptest.ResponseRecorder")
return
}
fields := rw.Fields()
if tt.ok {
if body := rr.Body.String(); body != "{\"foo\":\"bar\"}\n" {
t.Errorf(`Unexpected body = %v, want {"foo":"bar"}`, body)
}
if len(fields) != 0 {
t.Errorf("ResponseLogger fields = %v, wants 0 elements", fields)
}
} else {
if body := rr.Body.String(); body != "" {
t.Errorf("Unexpected body = %s, want empty string", body)
}
if len(fields) != 1 {
t.Errorf("ResponseLogger fields = %v, wants 1 element", fields)
}
}
})
}
}
func TestReadJSON(t *testing.T) {
type args struct {
r io.Reader
v interface{}
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ReadJSON(tt.args.r, &tt.args.v)
if (err != nil) != tt.wantErr {
t.Errorf("ReadJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
e, ok := err.(*errs.Error)
if ok {
if code := e.StatusCode(); code != 400 {
t.Errorf("error.StatusCode() = %v, wants 400", code)
}
} else {
t.Errorf("error type = %T, wants *Error", err)
}
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
t.Errorf("ReadJSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
}
})
}
}

@ -6,10 +6,12 @@ import (
"net/http"
"github.com/go-chi/chi"
"github.com/smallstep/certificates/api"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
)
const (
@ -44,11 +46,11 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP {
provName := chi.URLParam(r, "provisionerName")
eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if !eabEnabled {
api.WriteError(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
return
}
ctx = context.WithValue(ctx, provisionerContextKey, prov)
@ -101,15 +103,15 @@ func NewACMEAdminResponder() *ACMEAdminResponder {
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint
func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
}
// CreateExternalAccountKey writes the response for the EAB key POST endpoint
func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
}
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
}

@ -5,10 +5,14 @@ import (
"net/http"
"github.com/go-chi/chi"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
)
type adminAuthority interface {
@ -82,28 +86,28 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
adm, ok := h.auth.LoadAdminByID(id)
if !ok {
api.WriteError(w, admin.NewError(admin.ErrorNotFoundType,
render.Error(w, admin.NewError(admin.ErrorNotFoundType,
"admin %s not found", id))
return
}
api.ProtoJSON(w, adm)
render.ProtoJSON(w, adm)
}
// GetAdmins returns a segment of admins associated with the authority.
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r)
if err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err,
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
"error parsing cursor and limit from query params"))
return
}
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit)
if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
return
}
api.JSON(w, &GetAdminsResponse{
render.JSON(w, &GetAdminsResponse{
Admins: admins,
NextCursor: nextCursor,
})
@ -112,19 +116,19 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
// CreateAdmin creates a new admin.
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
var body CreateAdminRequest
if err := api.ReadJSON(r.Body, &body); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
p, err := h.auth.LoadProvisionerByName(body.Provisioner)
if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
return
}
adm := &linkedca.Admin{
@ -134,11 +138,11 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
}
// Store to authority collection.
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error storing admin"))
render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
return
}
api.ProtoJSONStatus(w, adm, http.StatusCreated)
render.ProtoJSONStatus(w, adm, http.StatusCreated)
}
// DeleteAdmin deletes admin.
@ -146,23 +150,23 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
return
}
api.JSON(w, &DeleteResponse{Status: "ok"})
render.JSON(w, &DeleteResponse{Status: "ok"})
}
// UpdateAdmin updates an existing admin.
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
var body UpdateAdminRequest
if err := api.ReadJSON(r.Body, &body); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -170,9 +174,9 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error updating admin %s", id))
render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
return
}
api.ProtoJSON(w, adm)
render.ProtoJSON(w, adm)
}

@ -4,7 +4,7 @@ import (
"context"
"net/http"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
)
@ -15,7 +15,7 @@ type nextHTTP = func(http.ResponseWriter, *http.Request)
func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
if !h.auth.IsAdminAPIEnabled() {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType,
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
"administration API not enabled"))
return
}
@ -28,14 +28,14 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
tok := r.Header.Get("Authorization")
if tok == "" {
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType,
render.Error(w, admin.NewError(admin.ErrorUnauthorizedType,
"missing authorization header token"))
return
}
adm, err := h.auth.AuthorizeAdminToken(r, tok)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}

@ -4,12 +4,16 @@ import (
"net/http"
"github.com/go-chi/chi"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"go.step.sm/linkedca"
)
// GetProvisionersResponse is the type for GET /admin/provisioners responses.
@ -31,39 +35,39 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
)
if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return
}
} else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return
}
}
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
api.ProtoJSON(w, prov)
render.ProtoJSON(w, prov)
}
// GetProvisioners returns the given segment of provisioners associated with the authority.
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r)
if err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err,
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
"error parsing cursor and limit from query params"))
return
}
p, next, err := h.auth.GetProvisioners(cursor, limit)
if err != nil {
api.WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
api.JSON(w, &GetProvisionersResponse{
render.JSON(w, &GetProvisionersResponse{
Provisioners: p,
NextCursor: next,
})
@ -72,22 +76,22 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
// CreateProvisioner creates a new prov.
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
var prov = new(linkedca.Provisioner)
if err := api.ReadProtoJSON(r.Body, prov); err != nil {
api.WriteError(w, err)
if err := read.ProtoJSON(r.Body, prov); err != nil {
render.Error(w, err)
return
}
// TODO: Validate inputs
if err := authority.ValidateClaims(prov.Claims); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
return
}
api.ProtoJSONStatus(w, prov, http.StatusCreated)
render.ProtoJSONStatus(w, prov, http.StatusCreated)
}
// DeleteProvisioner deletes a provisioner.
@ -101,75 +105,75 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
)
if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return
}
} else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return
}
}
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
return
}
api.JSON(w, &DeleteResponse{Status: "ok"})
render.JSON(w, &DeleteResponse{Status: "ok"})
}
// UpdateProvisioner updates an existing prov.
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
var nu = new(linkedca.Provisioner)
if err := api.ReadProtoJSON(r.Body, nu); err != nil {
api.WriteError(w, err)
if err := read.ProtoJSON(r.Body, nu); err != nil {
render.Error(w, err)
return
}
name := chi.URLParam(r, "name")
_old, err := h.auth.LoadProvisionerByName(name)
if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
return
}
old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID())
if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
return
}
if nu.Id != old.Id {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner ID"))
render.Error(w, admin.NewErrorISE("cannot change provisioner ID"))
return
}
if nu.Type != old.Type {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner type"))
render.Error(w, admin.NewErrorISE("cannot change provisioner type"))
return
}
if nu.AuthorityId != old.AuthorityId {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner authorityID"))
render.Error(w, admin.NewErrorISE("cannot change provisioner authorityID"))
return
}
if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner createdAt"))
render.Error(w, admin.NewErrorISE("cannot change provisioner createdAt"))
return
}
if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner deletedAt"))
render.Error(w, admin.NewErrorISE("cannot change provisioner deletedAt"))
return
}
// TODO: Validate inputs
if err := authority.ValidateClaims(nu.Claims); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
api.ProtoJSON(w, nu)
render.ProtoJSON(w, nu)
}

@ -3,13 +3,10 @@ package admin
import (
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/api/render"
)
// ProblemType is the type of the Admin problem.
@ -197,27 +194,9 @@ func (e *Error) ToLog() (interface{}, error) {
return string(b), nil
}
// WriteError writes to w a JSON representation of the given error.
func WriteError(w http.ResponseWriter, err *Error) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(err.StatusCode())
err.Message = err.Err.Error()
// Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err.Err,
})
if os.Getenv("STEPDEBUG") == "1" {
if e, ok := err.Err.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
}
}
}
// Render implements render.RenderableError for Error.
func (e *Error) Render(w http.ResponseWriter) {
e.Message = e.Err.Error()
if err := json.NewEncoder(w).Encode(err); err != nil {
log.Println(err)
}
render.JSONStatus(w, e, e.StatusCode())
}

@ -70,14 +70,24 @@ type Authority struct {
startTime time.Time
// Custom functions
sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error)
sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)
getIdentityFunc provisioner.GetIdentityFunc
sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error)
sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)
getIdentityFunc provisioner.GetIdentityFunc
authorizeRenewFunc provisioner.AuthorizeRenewFunc
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
adminMutex sync.RWMutex
}
type Info struct {
StartTime time.Time
RootX509Certs []*x509.Certificate
SSHCAUserPublicKey []byte
SSHCAHostPublicKey []byte
DNSNames []string
}
// New creates and initiates a new Authority type.
func New(cfg *config.Config, opts ...Option) (*Authority, error) {
err := cfg.Validate()
@ -175,7 +185,7 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error {
// Create provisioner collection.
provClxn := provisioner.NewCollection(provisionerConfig.Audiences)
for _, p := range provList {
if err := p.Init(*provisionerConfig); err != nil {
if err := p.Init(provisionerConfig); err != nil {
return err
}
if err := provClxn.Store(p); err != nil {
@ -251,6 +261,21 @@ func (a *Authority) init() error {
}
}
// Initialize linkedca client if necessary. On a linked RA, the issuer
// configuration might come from majordomo.
var linkedcaClient *linkedCaClient
if a.config.AuthorityConfig.EnableAdmin && a.linkedCAToken != "" && a.adminDB == nil {
linkedcaClient, err = newLinkedCAClient(a.linkedCAToken)
if err != nil {
return err
}
// If authorityId is configured make sure it matches the one in the token
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, linkedcaClient.authorityID) {
return errors.New("error initializing linkedca: token authority and configured authority do not match")
}
linkedcaClient.Run()
}
// Initialize the X.509 CA Service if it has not been set in the options.
if a.x509CAService == nil {
var options casapi.Options
@ -258,6 +283,22 @@ func (a *Authority) init() error {
options = *a.config.AuthorityConfig.Options
}
// Configure linked RA
if linkedcaClient != nil && options.CertificateAuthority == "" {
conf, err := linkedcaClient.GetConfiguration(context.Background())
if err != nil {
return err
}
if conf.RaConfig != nil {
options.CertificateAuthority = conf.RaConfig.CaUrl
options.CertificateAuthorityFingerprint = conf.RaConfig.Fingerprint
options.CertificateIssuer = &casapi.CertificateIssuer{
Type: conf.RaConfig.Provisioner.Type.String(),
Provisioner: conf.RaConfig.Provisioner.Name,
}
}
}
// Set the issuer password if passed in the flags.
if options.CertificateIssuer != nil && a.issuerPassword != nil {
options.CertificateIssuer.Password = string(a.issuerPassword)
@ -292,8 +333,6 @@ func (a *Authority) init() error {
return err
}
a.rootX509Certs = append(a.rootX509Certs, resp.RootCertificate)
sum := sha256.Sum256(resp.RootCertificate.Raw)
log.Printf("Using root fingerprint '%s'", hex.EncodeToString(sum[:]))
}
}
@ -479,24 +518,13 @@ func (a *Authority) init() error {
// Initialize step-ca Admin Database if it's not already initialized using
// WithAdminDB.
if a.adminDB == nil {
if a.linkedCAToken == "" {
// Check if AuthConfig already exists
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
if err != nil {
return err
}
if linkedcaClient != nil {
a.adminDB = linkedcaClient
} else {
// Use the linkedca client as the admindb.
client, err := newLinkedCAClient(a.linkedCAToken)
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
if err != nil {
return err
}
// If authorityId is configured make sure it matches the one in the token
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, client.authorityID) {
return errors.New("error initializing linkedca: token authority and configured authority do not match")
}
client.Run()
a.adminDB = client
}
}
@ -559,6 +587,21 @@ func (a *Authority) GetAdminDatabase() admin.DB {
return a.adminDB
}
func (a *Authority) GetInfo() Info {
ai := Info{
StartTime: a.startTime,
RootX509Certs: a.rootX509Certs,
DNSNames: a.config.DNSNames,
}
if a.sshCAUserCertSignKey != nil {
ai.SSHCAUserPublicKey = ssh.MarshalAuthorizedKey(a.sshCAUserCertSignKey.PublicKey())
}
if a.sshCAHostCertSignKey != nil {
ai.SSHCAHostPublicKey = ssh.MarshalAuthorizedKey(a.sshCAHostCertSignKey.PublicKey())
}
return ai
}
// IsAdminAPIEnabled returns a boolean indicating whether the Admin API has
// been enabled.
func (a *Authority) IsAdminAPIEnabled() bool {

@ -6,6 +6,7 @@ import (
"crypto/x509"
"encoding/hex"
"net/http"
"net/url"
"strconv"
"strings"
"time"
@ -276,6 +277,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
serial := cert.SerialNumber.String()
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
isRevoked, err := a.IsRevoked(serial)
if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
@ -283,7 +285,6 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
if isRevoked {
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
}
p, ok := a.provisioners.LoadByCertificate(cert)
if !ok {
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
@ -371,3 +372,80 @@ func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error
}
return nil
}
// AuthorizeRenewToken validates the renew token and returns the leaf
// certificate in the x5cInsecure header.
func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
var claims jose.Claims
jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs)
if err != nil {
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token"))
}
leaf := chain[0][0]
if err := jwt.Claims(leaf.PublicKey, &claims); err != nil {
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token"))
}
p, ok := a.provisioners.LoadByCertificate(leaf)
if !ok {
return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate")
}
if err := a.UseToken(ott, p); err != nil {
return nil, err
}
if err := claims.ValidateWithLeeway(jose.Expected{
Issuer: p.GetName(),
Subject: leaf.Subject.CommonName,
Time: time.Now().UTC(),
}, time.Minute); err != nil {
switch err {
case jose.ErrInvalidIssuer:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid issuer claim (iss)"))
case jose.ErrInvalidSubject:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid subject claim (sub)"))
case jose.ErrNotValidYet:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token not valid yet (nbf)"))
case jose.ErrExpired:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token is expired (exp)"))
case jose.ErrIssuedInTheFuture:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token issued in the future (iat)"))
default:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token"))
}
}
audiences := a.config.GetAudiences().Renew
if !matchesAudience(claims.Audience, audiences) {
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
}
return leaf, nil
}
// matchesAudience returns true if A and B share at least one element.
func matchesAudience(as, bs []string) bool {
if len(bs) == 0 || len(as) == 0 {
return false
}
for _, b := range bs {
for _, a := range as {
if b == a || stripPort(a) == stripPort(b) {
return true
}
}
}
return false
}
// stripPort attempts to strip the port from the given url. If parsing the url
// produces errors it will just return the passed argument.
func stripPort(rawurl string) string {
u, err := url.Parse(rawurl)
if err != nil {
return rawurl
}
u.Host = u.Hostname()
return u.String()
}

@ -3,24 +3,32 @@ package authority
import (
"context"
"crypto"
"crypto/ed25519"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"errors"
"fmt"
"net/http"
"reflect"
"strconv"
"testing"
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil"
"golang.org/x/crypto/ssh"
)
var testAudiences = provisioner.Audiences{
@ -305,8 +313,8 @@ func TestAuthority_authorizeToken(t *testing.T) {
p, err := tc.auth.authorizeToken(context.Background(), tc.token)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -391,8 +399,8 @@ func TestAuthority_authorizeRevoke(t *testing.T) {
if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -476,14 +484,14 @@ func TestAuthority_authorizeSign(t *testing.T) {
got, err := tc.auth.authorizeSign(context.Background(), tc.token)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Len(t, 7, got)
assert.Len(t, 8, got)
}
}
})
@ -735,8 +743,8 @@ func TestAuthority_Authorize(t *testing.T) {
if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, got)
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -753,6 +761,7 @@ func TestAuthority_Authorize(t *testing.T) {
func TestAuthority_authorizeRenew(t *testing.T) {
fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt")
fooCrt.NotAfter = time.Now().Add(time.Hour)
assert.FatalError(t, err)
renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt")
@ -822,7 +831,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
return &authorizeTest{
auth: a,
cert: renewDisabledCrt,
err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'renew_disabled'"),
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'renew_disabled'"),
code: http.StatusUnauthorized,
}
},
@ -847,7 +856,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
err := tc.auth.authorizeRenew(tc.cert)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -909,6 +918,7 @@ func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *provisioner.
}
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
now := time.Now()
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
if err != nil {
return nil, nil, err
@ -917,6 +927,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
if err != nil {
return nil, nil, err
}
if cert.ValidAfter == 0 {
cert.ValidAfter = uint64(now.Unix())
}
if cert.ValidBefore == 0 {
cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
}
if err := cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err
}
@ -988,8 +1004,8 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
got, err := tc.auth.authorizeSSHSign(context.Background(), tc.token)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -1003,6 +1019,23 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
}
func TestAuthority_authorizeSSHRenew(t *testing.T) {
now := time.Now().UTC()
sshpop := func(a *Authority) (*ssh.Certificate, string) {
p, ok := a.provisioners.Load("sshpop/sshpop")
assert.Fatal(t, ok, "sshpop provisioner not found in test authority")
key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
assert.FatalError(t, err)
signer, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
sshSigner, err := ssh.NewSignerFromSigner(signer)
assert.FatalError(t, err)
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
assert.FatalError(t, err)
token, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", []string{"foo.smallstep.com"}, now, jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return cert, token
}
a := testAuthority(t)
jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
@ -1012,8 +1045,6 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err)
now := time.Now().UTC()
validIssuer := "step-cli"
type authorizeTest struct {
@ -1050,27 +1081,34 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
code: http.StatusUnauthorized,
}
},
"fail/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
return errs.Forbidden("forbidden")
}))
_, token := sshpop(aa)
return &authorizeTest{
auth: aa,
token: token,
err: errors.New("authority.authorizeSSHRenew: forbidden"),
code: http.StatusForbidden,
}
},
"ok": func(t *testing.T) *authorizeTest {
key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
assert.FatalError(t, err)
signer, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
sshSigner, err := ssh.NewSignerFromSigner(signer)
assert.FatalError(t, err)
cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
assert.FatalError(t, err)
p, ok := a.provisioners.Load("sshpop/sshpop")
assert.Fatal(t, ok, "sshpop provisioner not found in test authority")
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop",
[]string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
cert, token := sshpop(a)
return &authorizeTest{
auth: a,
token: tok,
token: token,
cert: cert,
}
},
"ok/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
return nil
}))
cert, token := sshpop(aa)
return &authorizeTest{
auth: aa,
token: token,
cert: cert,
}
},
@ -1083,8 +1121,8 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
got, err := tc.auth.authorizeSSHRenew(context.Background(), tc.token)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -1183,8 +1221,8 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) {
if err := tc.auth.authorizeSSHRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -1276,8 +1314,8 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
cert, signOpts, err := tc.auth.authorizeSSHRekey(context.Background(), tc.token)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -1290,3 +1328,283 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
})
}
}
func TestAuthority_AuthorizeRenewToken(t *testing.T) {
ctx := context.Background()
type stepProvisionerASN1 struct {
Type int
Name []byte
CredentialID []byte
KeyValuePairs []string `asn1:"optional,omitempty"`
}
_, signer, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
csr, err := x509util.CreateCertificateRequest("test.example.com", []string{"test.example.com"}, signer)
if err != nil {
t.Fatal(err)
}
_, otherSigner, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) {
chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
if err != nil {
t.Fatal(err)
}
var x5c []string
for _, c := range chain {
x5c = append(x5c, base64.StdEncoding.EncodeToString(c.Raw))
}
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("x5cInsecure", x5c)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: key}, so)
if err != nil {
t.Fatal(err)
}
s, err := jose.Signed(sig).Claims(claims).CompactSerialize()
if err != nil {
t.Fatal(err)
}
return s, chain[0]
}
now := time.Now()
a1 := testAuthority(t)
t1, c1 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
t2, c2 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now.Add(-time.Hour)
cert.NotAfter = now.Add(-time.Minute)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badIssuer, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "bad-issuer",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSubject, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "bad-subject",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badAudience, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
type args struct {
ctx context.Context
ott string
}
tests := []struct {
name string
authority *Authority
args args
want *x509.Certificate
wantErr bool
}{
{"ok", a1, args{ctx, t1}, c1, false},
{"ok expired cert", a1, args{ctx, t2}, c2, false},
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
{"fail token reuse", a1, args{ctx, t1}, nil, true},
{"fail token signature", a1, args{ctx, badSigner}, nil, true},
{"fail token provisioner", a1, args{ctx, badProvisioner}, nil, true},
{"fail token iss", a1, args{ctx, badIssuer}, nil, true},
{"fail token sub", a1, args{ctx, badSubject}, nil, true},
{"fail token iat", a1, args{ctx, badNotBefore}, nil, true},
{"fail token iat", a1, args{ctx, badExpiry}, nil, true},
{"fail token iat", a1, args{ctx, badIssuedAt}, nil, true},
{"fail token aud", a1, args{ctx, badAudience}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.authority.AuthorizeRenewToken(tt.args.ctx, tt.args.ott)
if (err != nil) != tt.wantErr {
t.Errorf("Authority.AuthorizeRenewToken() error = %+v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.AuthorizeRenewToken() = %v, want %v", got, tt.want)
}
})
}
}

@ -26,23 +26,27 @@ var (
DefaultBackdate = time.Minute
// DefaultDisableRenewal disables renewals per provisioner.
DefaultDisableRenewal = false
// DefaultAllowRenewAfterExpiry allows renewals even if the certificate is
// expired.
DefaultAllowRenewAfterExpiry = false
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally
// for all provisioners.
DefaultEnableSSHCA = false
// GlobalProvisionerClaims default claims for the Authority. Can be overridden
// by provisioner specific claims.
GlobalProvisionerClaims = provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DisableRenewal: &DefaultDisableRenewal,
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &DefaultEnableSSHCA,
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &DefaultEnableSSHCA,
DisableRenewal: &DefaultDisableRenewal,
AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry,
}
)
@ -273,28 +277,32 @@ func (c *Config) GetAudiences() provisioner.Audiences {
}
for _, name := range c.DNSNames {
hostname := toHostname(name)
audiences.Sign = append(audiences.Sign,
fmt.Sprintf("https://%s/1.0/sign", toHostname(name)),
fmt.Sprintf("https://%s/sign", toHostname(name)),
fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)),
fmt.Sprintf("https://%s/ssh/sign", toHostname(name)))
fmt.Sprintf("https://%s/1.0/sign", hostname),
fmt.Sprintf("https://%s/sign", hostname),
fmt.Sprintf("https://%s/1.0/ssh/sign", hostname),
fmt.Sprintf("https://%s/ssh/sign", hostname))
audiences.Renew = append(audiences.Renew,
fmt.Sprintf("https://%s/1.0/renew", hostname),
fmt.Sprintf("https://%s/renew", hostname))
audiences.Revoke = append(audiences.Revoke,
fmt.Sprintf("https://%s/1.0/revoke", toHostname(name)),
fmt.Sprintf("https://%s/revoke", toHostname(name)))
fmt.Sprintf("https://%s/1.0/revoke", hostname),
fmt.Sprintf("https://%s/revoke", hostname))
audiences.SSHSign = append(audiences.SSHSign,
fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)),
fmt.Sprintf("https://%s/ssh/sign", toHostname(name)),
fmt.Sprintf("https://%s/1.0/sign", toHostname(name)),
fmt.Sprintf("https://%s/sign", toHostname(name)))
fmt.Sprintf("https://%s/1.0/ssh/sign", hostname),
fmt.Sprintf("https://%s/ssh/sign", hostname),
fmt.Sprintf("https://%s/1.0/sign", hostname),
fmt.Sprintf("https://%s/sign", hostname))
audiences.SSHRevoke = append(audiences.SSHRevoke,
fmt.Sprintf("https://%s/1.0/ssh/revoke", toHostname(name)),
fmt.Sprintf("https://%s/ssh/revoke", toHostname(name)))
fmt.Sprintf("https://%s/1.0/ssh/revoke", hostname),
fmt.Sprintf("https://%s/ssh/revoke", hostname))
audiences.SSHRenew = append(audiences.SSHRenew,
fmt.Sprintf("https://%s/1.0/ssh/renew", toHostname(name)),
fmt.Sprintf("https://%s/ssh/renew", toHostname(name)))
fmt.Sprintf("https://%s/1.0/ssh/renew", hostname),
fmt.Sprintf("https://%s/ssh/renew", hostname))
audiences.SSHRekey = append(audiences.SSHRekey,
fmt.Sprintf("https://%s/1.0/ssh/rekey", toHostname(name)),
fmt.Sprintf("https://%s/ssh/rekey", toHostname(name)))
fmt.Sprintf("https://%s/1.0/ssh/rekey", hostname),
fmt.Sprintf("https://%s/ssh/rekey", hostname))
}
return audiences

@ -15,6 +15,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
@ -151,13 +152,21 @@ func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linked
}
func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
resp, err := c.GetConfiguration(ctx)
if err != nil {
return nil, err
}
return resp.Provisioners, nil
}
func (c *linkedCaClient) GetConfiguration(ctx context.Context) (*linkedca.ConfigurationResponse, error) {
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
AuthorityId: c.authorityID,
})
if err != nil {
return nil, errors.Wrap(err, "error getting provisioners")
return nil, errors.Wrap(err, "error getting configuration")
}
return resp.Provisioners, nil
return resp, nil
}
func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
@ -204,11 +213,9 @@ func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Adm
}
func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
AuthorityId: c.authorityID,
})
resp, err := c.GetConfiguration(ctx)
if err != nil {
return nil, errors.Wrap(err, "error getting admins")
return nil, err
}
return resp.Admins, nil
}
@ -228,12 +235,13 @@ func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error {
return errors.Wrap(err, "error deleting admin")
}
func (c *linkedCaClient) StoreCertificateChain(fullchain ...*x509.Certificate) error {
func (c *linkedCaClient) StoreCertificateChain(prov provisioner.Interface, fullchain ...*x509.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
PemCertificate: serializeCertificateChain(fullchain[0]),
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
Provisioner: createProvisionerIdentity(prov),
})
return errors.Wrap(err, "error posting certificate")
}
@ -310,6 +318,17 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
}
func createProvisionerIdentity(prov provisioner.Interface) *linkedca.ProvisionerIdentity {
if prov == nil {
return nil
}
return &linkedca.ProvisionerIdentity{
Id: prov.GetID(),
Type: linkedca.Provisioner_Type(prov.GetType()),
Name: prov.GetName(),
}
}
func serializeCertificate(crt *x509.Certificate) string {
if crt == nil {
return ""

@ -92,6 +92,24 @@ func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, e
}
}
// WithAuthorizeRenewFunc sets a custom function that authorizes the renewal of
// an X.509 certificate.
func WithAuthorizeRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error) Option {
return func(a *Authority) error {
a.authorizeRenewFunc = fn
return nil
}
}
// WithAuthorizeSSHRenewFunc sets a custom function that authorizes the renewal
// of a SSH certificate.
func WithAuthorizeSSHRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error) Option {
return func(a *Authority) error {
a.authorizeSSHRenewFunc = fn
return nil
}
}
// WithSSHBastionFunc sets a custom function to get the bastion for a
// given user-host pair.
func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option {
@ -145,6 +163,22 @@ func WithX509Signer(crt *x509.Certificate, s crypto.Signer) Option {
}
}
// WithX509SignerFunc defines the function used to get the chain of certificates
// and signer used when we sign X.509 certificates.
func WithX509SignerFunc(fn func() ([]*x509.Certificate, crypto.Signer, error)) Option {
return func(a *Authority) error {
srv, err := cas.New(context.Background(), casapi.Options{
Type: casapi.SoftCAS,
CertificateSigner: fn,
})
if err != nil {
return err
}
a.x509CAService = srv
return nil
}
}
// WithSSHUserSigner defines the signer used to sign SSH user certificates.
func WithSSHUserSigner(s crypto.Signer) Option {
return func(a *Authority) error {

@ -6,7 +6,6 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
)
// ACME is the acme provisioner type, an entity that can authorize the ACME
@ -24,7 +23,7 @@ type ACME struct {
RequireEAB bool `json:"requireEAB,omitempty"`
Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"`
claimer *Claimer
ctl *Controller
}
// GetID returns the provisioner unique identifier.
@ -69,7 +68,7 @@ func (p *ACME) GetOptions() *Options {
// DefaultTLSCertDuration returns the default TLS cert duration enforced by
// the provisioner.
func (p *ACME) DefaultTLSCertDuration() time.Duration {
return p.claimer.DefaultTLSCertDuration()
return p.ctl.Claimer.DefaultTLSCertDuration()
}
// Init initializes and validates the fields of a JWK type.
@ -81,12 +80,8 @@ func (p *ACME) Init(config Config) (err error) {
return errors.New("provisioner name cannot be empty")
}
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
return err
p.ctl, err = NewController(p, p.Claims, config)
return
}
// AuthorizeSign does not do any validation, because all validation is handled
@ -94,13 +89,14 @@ func (p *ACME) Init(config Config) (err error) {
// on the resulting certificate.
func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{
p,
// modifiers / withOptions
newProvisionerExtensionOption(TypeACME, p.Name, ""),
newForceCNOption(p.ForceCN),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators
defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
}, nil
}
@ -118,8 +114,5 @@ func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error {
// revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals.
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName())
}
return nil
return p.ctl.AuthorizeRenew(ctx, cert)
}

@ -3,13 +3,14 @@ package provisioner
import (
"context"
"crypto/x509"
"errors"
"fmt"
"net/http"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/api/render"
)
func TestACME_Getters(t *testing.T) {
@ -91,6 +92,7 @@ func TestACME_Init(t *testing.T) {
}
func TestACME_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
type test struct {
p *ACME
cert *x509.Certificate
@ -104,21 +106,27 @@ func TestACME_AuthorizeRenew(t *testing.T) {
// disable renewal
disable := true
p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
return test{
p: p,
cert: &x509.Certificate{},
p: p,
cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
code: http.StatusUnauthorized,
err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()),
err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
}
},
"ok": func(t *testing.T) test {
p, err := generateACME()
assert.FatalError(t, err)
return test{
p: p,
cert: &x509.Certificate{},
p: p,
cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
}
},
}
@ -126,8 +134,8 @@ func TestACME_AuthorizeRenew(t *testing.T) {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -161,31 +169,32 @@ func TestACME_AuthorizeSign(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) && assert.NotNil(t, opts) {
assert.Len(t, 5, opts)
assert.Len(t, 6, opts)
for _, o := range opts {
switch v := o.(type) {
case *ACME:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeACME))
assert.Equals(t, v.Type, TypeACME)
assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs)
case *forceCNOption:
assert.Equals(t, v.ForceCN, tc.p.ForceCN)
case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration())
assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator:
case *validityValidator:
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
}
}

@ -264,9 +264,8 @@ type AWS struct {
IIDRoots string `json:"iidRoots,omitempty"`
Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"`
claimer *Claimer
config *awsConfig
audiences Audiences
ctl *Controller
}
// GetID returns the provisioner unique identifier.
@ -400,15 +399,11 @@ func (p *AWS) Init(config Config) (err error) {
case p.InstanceAge.Value() < 0:
return errors.New("provisioner instanceAge cannot be negative")
}
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Add default config
if p.config, err = newAWSConfig(p.IIDRoots); err != nil {
return err
}
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
// validate IMDS versions
if len(p.IMDSVersions) == 0 {
@ -425,7 +420,9 @@ func (p *AWS) Init(config Config) (err error) {
}
}
return nil
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
return
}
// AuthorizeSign validates the given token and returns the sign options that
@ -470,14 +467,15 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
}
return append(so,
p,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators
defaultPublicKeyValidator{},
commonNameValidator(payload.Claims.Subject),
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
), nil
}
@ -486,10 +484,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals.
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner '%s'", p.GetName())
}
return nil
return p.ctl.AuthorizeRenew(ctx, cert)
}
// assertConfig initializes the config if it has not been initialized
@ -664,7 +659,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
}
// validate audiences with the defaults
if !matchesAudience(payload.Audience, p.audiences.Sign) {
if !matchesAudience(payload.Audience, p.ctl.Audiences.Sign) {
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
}
@ -704,7 +699,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName())
}
claims, err := p.authorizeToken(token)
@ -752,11 +747,11 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate user SignSSHOptions.
sshCertOptionsValidator(defaults),
// Set the validity bounds if not set.
&sshDefaultDuration{p.claimer},
&sshDefaultDuration{p.ctl.Claimer},
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertValidityValidator{p.claimer},
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
), nil

@ -9,6 +9,7 @@ import (
"crypto/x509"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"
"net"
"net/http"
@ -17,10 +18,10 @@ import (
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
)
func TestAWS_Getters(t *testing.T) {
@ -521,8 +522,8 @@ func TestAWS_authorizeToken(t *testing.T) {
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -641,11 +642,11 @@ func TestAWS_AuthorizeSign(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{t1, "foo.local"}, 6, http.StatusOK, false},
{"ok", p2, args{t2, "instance-id"}, 10, http.StatusOK, false},
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 10, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 10, http.StatusOK, false},
{"ok", p1, args{t4, "instance-id"}, 6, http.StatusOK, false},
{"ok", p1, args{t1, "foo.local"}, 7, http.StatusOK, false},
{"ok", p2, args{t2, "instance-id"}, 11, http.StatusOK, false},
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 11, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 11, http.StatusOK, false},
{"ok", p1, args{t4, "instance-id"}, 7, http.StatusOK, false},
{"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true},
{"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true},
@ -668,27 +669,28 @@ func TestAWS_AuthorizeSign(t *testing.T) {
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
case err != nil:
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
default:
assert.Len(t, tt.wantLen, got)
for _, o := range got {
switch v := o.(type) {
case *AWS:
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeAWS))
assert.Equals(t, v.Type, TypeAWS)
assert.Equals(t, v.Name, tt.aws.GetName())
assert.Equals(t, v.CredentialID, tt.aws.Accounts[0])
assert.Len(t, 2, v.KeyValuePairs)
case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration())
assert.Equals(t, time.Duration(v), tt.aws.ctl.Claimer.DefaultTLSCertDuration())
case commonNameValidator:
assert.Equals(t, string(v), tt.args.cn)
case defaultPublicKeyValidator:
case *validityValidator:
assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration())
assert.Equals(t, v.min, tt.aws.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.aws.ctl.Claimer.MaxTLSCertDuration())
case ipAddressesValidator:
assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")})
case emailAddressesValidator:
@ -698,7 +700,7 @@ func TestAWS_AuthorizeSign(t *testing.T) {
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"})
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
}
}
@ -726,7 +728,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
// disable sshCA
disable := false
p3.Claims = &Claims{EnableSSHCA: &disable}
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com")
@ -747,7 +749,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SignSSHOptions{
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
@ -802,8 +804,8 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
@ -824,6 +826,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
}
func TestAWS_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateAWS()
assert.FatalError(t, err)
p2, err := generateAWS()
@ -832,7 +835,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
@ -845,16 +848,22 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{nil}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
{"ok", p1, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})

@ -30,7 +30,7 @@ const azureDefaultAudience = "https://management.azure.com/"
// azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim.
// Using case insensitive as resourceGroups appears as resourcegroups.
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.Compute/virtualMachines/([^/]+)$`)
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`)
type azureConfig struct {
oidcDiscoveryURL string
@ -89,15 +89,17 @@ type Azure struct {
Name string `json:"name"`
TenantID string `json:"tenantID"`
ResourceGroups []string `json:"resourceGroups"`
SubscriptionIDs []string `json:"subscriptionIDs"`
ObjectIDs []string `json:"objectIDs"`
Audience string `json:"audience,omitempty"`
DisableCustomSANs bool `json:"disableCustomSANs"`
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"`
claimer *Claimer
config *azureConfig
oidcConfig openIDConfiguration
keyStore *keyStore
ctl *Controller
}
// GetID returns the provisioner unique identifier.
@ -201,37 +203,34 @@ func (p *Azure) Init(config Config) (err error) {
case p.Audience == "": // use default audience
p.Audience = azureDefaultAudience
}
// Initialize config
p.assertConfig()
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Decode and validate openid-configuration endpoint
if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
return err
if err = getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
return
}
if err := p.oidcConfig.Validate(); err != nil {
return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL)
}
// Get JWK key set
if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil {
return err
return
}
return nil
p.ctl, err = NewController(p, p.Claims, config)
return
}
// authorizeToken returns the claims, name, group, error.
func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, error) {
// authorizeToken returns the claims, name, group, subscription, identityObjectID, error.
func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, string, string, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {
return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token")
return nil, "", "", "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token")
}
if len(jwt.Headers) == 0 {
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header")
return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header")
}
var found bool
@ -244,7 +243,7 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err
}
}
if !found {
return nil, "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token")
return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token")
}
if err := claims.ValidateWithLeeway(jose.Expected{
@ -252,26 +251,30 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err
Issuer: p.oidcConfig.Issuer,
Time: time.Now(),
}, 1*time.Minute); err != nil {
return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; failed to validate azure token payload")
return nil, "", "", "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; failed to validate azure token payload")
}
// Validate TenantID
if claims.TenantID != p.TenantID {
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")
return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")
}
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
if len(re) != 4 {
return nil, "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
if len(re) != 5 {
return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
}
group, name := re[2], re[3]
return &claims, name, group, nil
var subscription, group, name string
identityObjectID := claims.ObjectID
subscription, group, name = re[1], re[2], re[4]
return &claims, name, group, subscription, identityObjectID, nil
}
// AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation.
func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
_, name, group, err := p.authorizeToken(token)
_, name, group, subscription, identityObjectID, err := p.authorizeToken(token)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign")
}
@ -290,6 +293,34 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
}
}
// Filter by subscription id
if len(p.SubscriptionIDs) > 0 {
var found bool
for _, s := range p.SubscriptionIDs {
if s == subscription {
found = true
break
}
}
if !found {
return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid subscription id")
}
}
// Filter by Azure AD identity object id
if len(p.ObjectIDs) > 0 {
var found bool
for _, i := range p.ObjectIDs {
if i == identityObjectID {
found = true
break
}
}
if !found {
return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid identity object id")
}
}
// Template options
data := x509util.NewTemplateData()
data.SetCommonName(name)
@ -321,13 +352,14 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
}
return append(so,
p,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators
defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
), nil
}
@ -336,19 +368,16 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals.
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner '%s'", p.GetName())
}
return nil
return p.ctl.AuthorizeRenew(ctx, cert)
}
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName())
}
_, name, _, err := p.authorizeToken(token)
_, name, _, _, _, err := p.authorizeToken(token)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign")
}
@ -389,11 +418,11 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Validate user SignSSHOptions.
sshCertOptionsValidator(defaults),
// Set the validity bounds if not set.
&sshDefaultDuration{p.claimer},
&sshDefaultDuration{p.ctl.Claimer},
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertValidityValidator{p.claimer},
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
), nil

@ -8,6 +8,7 @@ import (
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"errors"
"fmt"
"net/http"
"net/http/httptest"
@ -15,10 +16,10 @@ import (
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
)
func TestAzure_Getters(t *testing.T) {
@ -95,7 +96,7 @@ func TestAzure_GetIdentityToken(t *testing.T) {
assert.FatalError(t, err)
t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
@ -237,7 +238,7 @@ func TestAzure_authorizeToken(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), jwk)
assert.FatalError(t, err)
return test{
@ -252,7 +253,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err)
defer srv.Close()
tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
@ -267,7 +268,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err)
defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
"foo", "subscriptionID", "resourceGroup", "virtualMachine",
"foo", "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
@ -321,7 +322,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err)
defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
return test{
@ -333,10 +334,10 @@ func TestAzure_authorizeToken(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if claims, name, group, err := tc.p.authorizeToken(tc.token); err != nil {
if claims, name, group, subscriptionID, objectID, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -348,6 +349,8 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.Equals(t, name, "virtualMachine")
assert.Equals(t, group, "resourceGroup")
assert.Equals(t, subscriptionID, "subscriptionID")
assert.Equals(t, objectID, "the-oid")
}
}
})
@ -382,6 +385,38 @@ func TestAzure_AuthorizeSign(t *testing.T) {
p4.oidcConfig = p1.oidcConfig
p4.keyStore = p1.keyStore
p5, err := generateAzure()
assert.FatalError(t, err)
p5.TenantID = p1.TenantID
p5.SubscriptionIDs = []string{"subscriptionID"}
p5.config = p1.config
p5.oidcConfig = p1.oidcConfig
p5.keyStore = p1.keyStore
p6, err := generateAzure()
assert.FatalError(t, err)
p6.TenantID = p1.TenantID
p6.SubscriptionIDs = []string{"foobarzar"}
p6.config = p1.config
p6.oidcConfig = p1.oidcConfig
p6.keyStore = p1.keyStore
p7, err := generateAzure()
assert.FatalError(t, err)
p7.TenantID = p1.TenantID
p7.ObjectIDs = []string{"the-oid"}
p7.config = p1.config
p7.oidcConfig = p1.oidcConfig
p7.keyStore = p1.keyStore
p8, err := generateAzure()
assert.FatalError(t, err)
p8.TenantID = p1.TenantID
p8.ObjectIDs = []string{"foobarzar"}
p8.config = p1.config
p8.oidcConfig = p1.oidcConfig
p8.keyStore = p1.keyStore
badKey, err := generateJSONWebKey()
assert.FatalError(t, err)
@ -393,30 +428,38 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.FatalError(t, err)
t4, err := p4.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err)
t5, err := p5.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err)
t6, err := p6.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err)
t7, err := p6.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err)
t8, err := p6.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err)
t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience",
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), badKey)
assert.FatalError(t, err)
@ -431,11 +474,15 @@ func TestAzure_AuthorizeSign(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{t1}, 5, http.StatusOK, false},
{"ok", p2, args{t2}, 10, http.StatusOK, false},
{"ok", p1, args{t11}, 5, http.StatusOK, false},
{"ok", p1, args{t1}, 6, http.StatusOK, false},
{"ok", p2, args{t2}, 11, http.StatusOK, false},
{"ok", p1, args{t11}, 6, http.StatusOK, false},
{"ok", p5, args{t5}, 6, http.StatusOK, false},
{"ok", p7, args{t7}, 6, http.StatusOK, false},
{"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true},
{"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true},
{"fail subscription", p6, args{t6}, 0, http.StatusUnauthorized, true},
{"fail object id", p8, args{t8}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true},
{"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true},
@ -451,27 +498,28 @@ func TestAzure_AuthorizeSign(t *testing.T) {
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
case err != nil:
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
default:
assert.Len(t, tt.wantLen, got)
for _, o := range got {
switch v := o.(type) {
case *Azure:
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeAzure))
assert.Equals(t, v.Type, TypeAzure)
assert.Equals(t, v.Name, tt.azure.GetName())
assert.Equals(t, v.CredentialID, tt.azure.TenantID)
assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.azure.claimer.DefaultTLSCertDuration())
assert.Equals(t, time.Duration(v), tt.azure.ctl.Claimer.DefaultTLSCertDuration())
case commonNameValidator:
assert.Equals(t, string(v), "virtualMachine")
case defaultPublicKeyValidator:
case *validityValidator:
assert.Equals(t, v.min, tt.azure.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.azure.claimer.MaxTLSCertDuration())
assert.Equals(t, v.min, tt.azure.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.azure.ctl.Claimer.MaxTLSCertDuration())
case ipAddressesValidator:
assert.Equals(t, v, nil)
case emailAddressesValidator:
@ -481,7 +529,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"virtualMachine"})
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
}
}
@ -490,6 +538,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
}
func TestAzure_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateAzure()
assert.FatalError(t, err)
p2, err := generateAzure()
@ -498,7 +547,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
@ -511,16 +560,22 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{nil}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
{"ok", p1, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})
@ -549,7 +604,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
// disable sshCA
disable := false
p3.Claims = &Claims{EnableSSHCA: &disable}
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
t1, err := p1.GetIdentityToken("subject", "caURL")
@ -570,7 +625,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SignSSHOptions{
CertType: "host", Principals: []string{"virtualMachine"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
@ -615,8 +670,8 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else if assert.NotNil(t, got) {

@ -10,10 +10,10 @@ import (
// Claims so that individual provisioners can override global claims.
type Claims struct {
// TLS CA properties
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
DisableRenewal *bool `json:"disableRenewal,omitempty"`
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
// SSH CA properties
MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"`
MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"`
@ -22,6 +22,10 @@ type Claims struct {
MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"`
DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"`
EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
// Renewal properties
DisableRenewal *bool `json:"disableRenewal,omitempty"`
AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"`
}
// Claimer is the type that controls claims. It provides an interface around the
@ -40,19 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
// Claims returns the merge of the inner and global claims.
func (c *Claimer) Claims() Claims {
disableRenewal := c.IsDisableRenewal()
allowRenewAfterExpiry := c.AllowRenewAfterExpiry()
enableSSHCA := c.IsSSHCAEnabled()
return Claims{
MinTLSDur: &Duration{c.MinTLSCertDuration()},
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
DisableRenewal: &disableRenewal,
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
EnableSSHCA: &enableSSHCA,
MinTLSDur: &Duration{c.MinTLSCertDuration()},
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
EnableSSHCA: &enableSSHCA,
DisableRenewal: &disableRenewal,
AllowRenewAfterExpiry: &allowRenewAfterExpiry,
}
}
@ -102,6 +109,16 @@ func (c *Claimer) IsDisableRenewal() bool {
return *c.claims.DisableRenewal
}
// AllowRenewAfterExpiry returns if the renewal flow is authorized if the
// certificate is expired. If the property is not set within the provisioner
// then the global value from the authority configuration will be used.
func (c *Claimer) AllowRenewAfterExpiry() bool {
if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil {
return *c.global.AllowRenewAfterExpiry
}
return *c.claims.AllowRenewAfterExpiry
}
// DefaultSSHCertDuration returns the default SSH certificate duration for the
// given certificate type.
func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) {

@ -152,8 +152,8 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
// proper id to load the provisioner.
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) {
for _, e := range cert.Extensions {
if e.Id.Equal(stepOIDProvisioner) {
var provisioner stepProvisionerASN1
if e.Id.Equal(StepOIDProvisioner) {
var provisioner extensionASN1
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
return nil, false
}

@ -147,6 +147,17 @@ func TestCollection_LoadByToken(t *testing.T) {
}
func TestCollection_LoadByCertificate(t *testing.T) {
mustExtension := func(typ Type, name, credentialID string) pkix.Extension {
e := Extension{
Type: typ, Name: name, CredentialID: credentialID,
}
ext, err := e.ToExtension()
if err != nil {
t.Fatal(err)
}
return ext
}
p1, err := generateJWK()
assert.FatalError(t, err)
p2, err := generateOIDC()
@ -159,30 +170,21 @@ func TestCollection_LoadByCertificate(t *testing.T) {
byName.Store(p2.GetName(), p2)
byName.Store(p3.GetName(), p3)
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
assert.FatalError(t, err)
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
assert.FatalError(t, err)
ok3Ext, err := createProvisionerExtension(int(TypeACME), p3.Name, "")
assert.FatalError(t, err)
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
assert.FatalError(t, err)
ok1Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok1Ext},
Extensions: []pkix.Extension{mustExtension(1, p1.Name, p1.Key.KeyID)},
}
ok2Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok2Ext},
Extensions: []pkix.Extension{mustExtension(2, p2.Name, p2.ClientID)},
}
ok3Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok3Ext},
Extensions: []pkix.Extension{mustExtension(TypeACME, p3.Name, "")},
}
notFoundCert := &x509.Certificate{
Extensions: []pkix.Extension{notFoundExt},
Extensions: []pkix.Extension{mustExtension(1, "foo", "bar")},
}
badCert := &x509.Certificate{
Extensions: []pkix.Extension{
{Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")},
{Id: StepOIDProvisioner, Critical: false, Value: []byte("foobar")},
},
}

@ -0,0 +1,194 @@
package provisioner
import (
"context"
"crypto/x509"
"regexp"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
)
// Controller wraps a provisioner with other attributes useful in callback
// functions.
type Controller struct {
Interface
Audiences *Audiences
Claimer *Claimer
IdentityFunc GetIdentityFunc
AuthorizeRenewFunc AuthorizeRenewFunc
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
}
// NewController initializes a new provisioner controller.
func NewController(p Interface, claims *Claims, config Config) (*Controller, error) {
claimer, err := NewClaimer(claims, config.Claims)
if err != nil {
return nil, err
}
return &Controller{
Interface: p,
Audiences: &config.Audiences,
Claimer: claimer,
IdentityFunc: config.GetIdentityFunc,
AuthorizeRenewFunc: config.AuthorizeRenewFunc,
AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc,
}, nil
}
// GetIdentity returns the identity for a given email.
func (c *Controller) GetIdentity(ctx context.Context, email string) (*Identity, error) {
if c.IdentityFunc != nil {
return c.IdentityFunc(ctx, c.Interface, email)
}
return DefaultIdentityFunc(ctx, c.Interface, email)
}
// AuthorizeRenew returns nil if the given cert can be renewed, returns an error
// otherwise.
func (c *Controller) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if c.AuthorizeRenewFunc != nil {
return c.AuthorizeRenewFunc(ctx, c, cert)
}
return DefaultAuthorizeRenew(ctx, c, cert)
}
// AuthorizeSSHRenew returns nil if the given cert can be renewed, returns an
// error otherwise.
func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificate) error {
if c.AuthorizeSSHRenewFunc != nil {
return c.AuthorizeSSHRenewFunc(ctx, c, cert)
}
return DefaultAuthorizeSSHRenew(ctx, c, cert)
}
// Identity is the type representing an externally supplied identity that is used
// by provisioners to populate certificate fields.
type Identity struct {
Usernames []string `json:"usernames"`
Permissions `json:"permissions"`
}
// GetIdentityFunc is a function that returns an identity.
type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
// AuthorizeRenewFunc is a function that returns nil if the renewal of a
// certificate is enabled.
type AuthorizeRenewFunc func(ctx context.Context, p *Controller, cert *x509.Certificate) error
// AuthorizeSSHRenewFunc is a function that returns nil if the renewal of the
// given SSH certificate is enabled.
type AuthorizeSSHRenewFunc func(ctx context.Context, p *Controller, cert *ssh.Certificate) error
// DefaultIdentityFunc return a default identity depending on the provisioner
// type. For OIDC email is always present and the usernames might
// contain empty strings.
func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
switch k := p.(type) {
case *OIDC:
// OIDC principals would be:
// ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled
// 2. Sanitized local.
// 3. Raw local (if different).
// 4. Email address.
name := SanitizeSSHUserPrincipal(email)
if !sshUserRegex.MatchString(name) {
return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email)
}
usernames := []string{name}
if i := strings.LastIndex(email, "@"); i >= 0 {
usernames = append(usernames, email[:i])
}
usernames = append(usernames, email)
return &Identity{
Usernames: SanitizeStringSlices(usernames),
}, nil
default:
return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k)
}
}
// DefaultAuthorizeRenew is the default implementation of AuthorizeRenew. It
// will return an error if the provisioner has the renewal disabled, if the
// certificate is not yet valid or if the certificate is expired and renew after
// expiry is disabled.
func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certificate) error {
if p.Claimer.IsDisableRenewal() {
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
}
now := time.Now().Truncate(time.Second)
if now.Before(cert.NotBefore) {
return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano))
}
if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() {
return errs.Unauthorized("certificate has expired")
}
return nil
}
// DefaultAuthorizeSSHRenew is the default implementation of AuthorizeSSHRenew. It
// will return an error if the provisioner has the renewal disabled, if the
// certificate is not yet valid or if the certificate is expired and renew after
// expiry is disabled.
func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
if p.Claimer.IsDisableRenewal() {
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
}
unixNow := time.Now().Unix()
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
return errs.Unauthorized("certificate is not yet valid")
}
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() {
return errs.Unauthorized("certificate has expired")
}
return nil
}
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
// SanitizeStringSlices removes duplicated an empty strings.
func SanitizeStringSlices(original []string) []string {
output := []string{}
seen := make(map[string]struct{})
for _, entry := range original {
if entry == "" {
continue
}
if _, value := seen[entry]; !value {
seen[entry] = struct{}{}
output = append(output, entry)
}
}
return output
}
// SanitizeSSHUserPrincipal grabs an email or a string with the format
// local@domain and returns a sanitized version of the local, valid to be used
// as a user name. If the email starts with a letter between a and z, the
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
func SanitizeSSHUserPrincipal(email string) string {
if i := strings.LastIndex(email, "@"); i >= 0 {
email = email[:i]
}
return strings.Map(func(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r
case r >= '0' && r <= '9':
return r
case r == '-':
return '-'
case r == '.': // drop dots
return -1
default:
return '_'
}
}, strings.ToLower(email))
}

@ -0,0 +1,391 @@
package provisioner
import (
"context"
"crypto/x509"
"fmt"
"reflect"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
var trueValue = true
func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer {
t.Helper()
c, err := NewClaimer(claims, global)
if err != nil {
t.Fatal(err)
}
return c
}
func mustDuration(t *testing.T, s string) *Duration {
t.Helper()
d, err := NewDuration(s)
if err != nil {
t.Fatal(err)
}
return d
}
func TestNewController(t *testing.T) {
type args struct {
p Interface
claims *Claims
config Config
}
tests := []struct {
name string
args args
want *Controller
wantErr bool
}{
{"ok", args{&JWK{}, nil, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
}, false},
{"ok with claims", args{&JWK{}, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, globalProvisionerClaims),
}, false},
{"fail claimer", args{&JWK{}, &Claims{
MinTLSDur: mustDuration(t, "24h"),
MaxTLSDur: mustDuration(t, "2h"),
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config)
if (err != nil) != tt.wantErr {
t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewController() = %v, want %v", got, tt.want)
}
})
}
}
func TestController_GetIdentity(t *testing.T) {
ctx := context.Background()
type fields struct {
Interface Interface
IdentityFunc GetIdentityFunc
}
type args struct {
ctx context.Context
email string
}
tests := []struct {
name string
fields fields
args args
want *Identity
wantErr bool
}{
{"ok", fields{&OIDC{}, nil}, args{ctx, "jane@doe.org"}, &Identity{
Usernames: []string{"jane", "jane@doe.org"},
}, false},
{"ok custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
return &Identity{Usernames: []string{"jane"}}, nil
}}, args{ctx, "jane@doe.org"}, &Identity{
Usernames: []string{"jane"},
}, false},
{"fail provisioner", fields{&JWK{}, nil}, args{ctx, "jane@doe.org"}, nil, true},
{"fail custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
return nil, fmt.Errorf("an error")
}}, args{ctx, "jane@doe.org"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
Interface: tt.fields.Interface,
IdentityFunc: tt.fields.IdentityFunc,
}
got, err := c.GetIdentity(tt.args.ctx, tt.args.email)
if (err != nil) != tt.wantErr {
t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Controller.GetIdentity() = %v, want %v", got, tt.want)
}
})
}
}
func TestController_AuthorizeRenew(t *testing.T) {
ctx := context.Background()
now := time.Now().Truncate(time.Second)
type fields struct {
Interface Interface
Claimer *Claimer
AuthorizeRenewFunc AuthorizeRenewFunc
}
type args struct {
ctx context.Context
cert *x509.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return nil
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return nil
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, false},
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(time.Hour),
NotAfter: now.Add(2 * time.Hour),
}}, true},
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, true},
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return fmt.Errorf("an error")
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
Interface: tt.fields.Interface,
Claimer: tt.fields.Claimer,
AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc,
}
if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestController_AuthorizeSSHRenew(t *testing.T) {
ctx := context.Background()
now := time.Now()
type fields struct {
Interface Interface
Claimer *Claimer
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
}
type args struct {
ctx context.Context
cert *ssh.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return nil
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return nil
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, false},
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, true},
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(time.Hour).Unix()),
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
}}, true},
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, true},
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return fmt.Errorf("an error")
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
Interface: tt.fields.Interface,
Claimer: tt.fields.Claimer,
AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc,
}
if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestDefaultAuthorizeRenew(t *testing.T) {
ctx := context.Background()
now := time.Now().Truncate(time.Second)
type args struct {
ctx context.Context
p *Controller
cert *x509.Certificate
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, false},
{"fail disabled", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
{"fail not yet valid", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(time.Hour),
NotAfter: now.Add(2 * time.Hour),
}}, true},
{"fail expired", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestDefaultAuthorizeSSHRenew(t *testing.T) {
ctx := context.Background()
now := time.Now()
type args struct {
ctx context.Context
p *Controller
cert *ssh.Certificate
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, false},
{"fail disabled", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, true},
{"fail not yet valid", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(time.Hour).Unix()),
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
}}, true},
{"fail expired", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

@ -0,0 +1,73 @@
package provisioner
import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
)
var (
// StepOIDRoot is the root OID for smallstep.
StepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
// StepOIDProvisioner is the OID for the provisioner extension.
StepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(StepOIDRoot, 1)...)
)
// Extension is the Go representation of the provisioner extension.
type Extension struct {
Type Type
Name string
CredentialID string
KeyValuePairs []string
}
type extensionASN1 struct {
Type int
Name []byte
CredentialID []byte
KeyValuePairs []string `asn1:"optional,omitempty"`
}
// Marshal marshals the extension using encoding/asn1.
func (e *Extension) Marshal() ([]byte, error) {
return asn1.Marshal(extensionASN1{
Type: int(e.Type),
Name: []byte(e.Name),
CredentialID: []byte(e.CredentialID),
KeyValuePairs: e.KeyValuePairs,
})
}
// ToExtension returns the pkix.Extension representation of the provisioner
// extension.
func (e *Extension) ToExtension() (pkix.Extension, error) {
b, err := e.Marshal()
if err != nil {
return pkix.Extension{}, err
}
return pkix.Extension{
Id: StepOIDProvisioner,
Value: b,
}, nil
}
// GetProvisionerExtension goes through all the certificate extensions and
// returns the provisioner extension (1.3.6.1.4.1.37476.9000.64.1).
func GetProvisionerExtension(cert *x509.Certificate) (*Extension, bool) {
for _, e := range cert.Extensions {
if e.Id.Equal(StepOIDProvisioner) {
var provisioner extensionASN1
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
return nil, false
}
return &Extension{
Type: Type(provisioner.Type),
Name: string(provisioner.Name),
CredentialID: string(provisioner.CredentialID),
KeyValuePairs: provisioner.KeyValuePairs,
}, true
}
}
return nil, false
}

@ -0,0 +1,158 @@
package provisioner
import (
"crypto/x509"
"crypto/x509/pkix"
"reflect"
"testing"
"go.step.sm/crypto/pemutil"
)
func TestExtension_Marshal(t *testing.T) {
type fields struct {
Type Type
Name string
CredentialID string
KeyValuePairs []string
}
tests := []struct {
name string
fields fields
want []byte
wantErr bool
}{
{"ok", fields{TypeJWK, "name", "credentialID", nil}, []byte{
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44,
}, false},
{"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, []byte{
0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f,
0x13, 0x03, 0x62, 0x61, 0x72,
}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &Extension{
Type: tt.fields.Type,
Name: tt.fields.Name,
CredentialID: tt.fields.CredentialID,
KeyValuePairs: tt.fields.KeyValuePairs,
}
got, err := e.Marshal()
if (err != nil) != tt.wantErr {
t.Errorf("Extension.Marshal() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Extension.Marshal() = %x, want %v", got, tt.want)
}
})
}
}
func TestExtension_ToExtension(t *testing.T) {
type fields struct {
Type Type
Name string
CredentialID string
KeyValuePairs []string
}
tests := []struct {
name string
fields fields
want pkix.Extension
wantErr bool
}{
{"ok", fields{TypeJWK, "name", "credentialID", nil}, pkix.Extension{
Id: StepOIDProvisioner,
Value: []byte{
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44,
},
}, false},
{"ok empty pairs", fields{TypeJWK, "name", "credentialID", []string{}}, pkix.Extension{
Id: StepOIDProvisioner,
Value: []byte{
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44,
},
}, false},
{"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, pkix.Extension{
Id: StepOIDProvisioner,
Value: []byte{
0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f,
0x13, 0x03, 0x62, 0x61, 0x72,
},
}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &Extension{
Type: tt.fields.Type,
Name: tt.fields.Name,
CredentialID: tt.fields.CredentialID,
KeyValuePairs: tt.fields.KeyValuePairs,
}
got, err := e.ToExtension()
if (err != nil) != tt.wantErr {
t.Errorf("Extension.ToExtension() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Extension.ToExtension() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetProvisionerExtension(t *testing.T) {
mustCertificate := func(fn string) *x509.Certificate {
cert, err := pemutil.ReadCertificate(fn)
if err != nil {
t.Fatal(err)
}
return cert
}
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
args args
want *Extension
want1 bool
}{
{"ok", args{mustCertificate("testdata/certs/good-extension.crt")}, &Extension{
Type: TypeJWK,
Name: "mariano@smallstep.com",
CredentialID: "nvgnR8wSzpUlrt_tC3mvrhwhBx9Y7T1WL_JjcFVWYBQ",
}, true},
{"fail unmarshal", args{mustCertificate("testdata/certs/bad-extension.crt")}, nil, false},
{"missing extension", args{mustCertificate("testdata/certs/aws.crt")}, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := GetProvisionerExtension(tt.args.cert)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetProvisionerExtension() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("GetProvisionerExtension() got1 = %v, want %v", got1, tt.want1)
}
})
}
}

@ -88,10 +88,9 @@ type GCP struct {
InstanceAge Duration `json:"instanceAge,omitempty"`
Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"`
claimer *Claimer
config *gcpConfig
keyStore *keyStore
audiences Audiences
ctl *Controller
}
// GetID returns the provisioner unique identifier. The name should uniquely
@ -194,8 +193,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) {
}
// Init validates and initializes the GCP provisioner.
func (p *GCP) Init(config Config) error {
var err error
func (p *GCP) Init(config Config) (err error) {
switch {
case p.Type == "":
return errors.New("provisioner type cannot be empty")
@ -204,20 +202,18 @@ func (p *GCP) Init(config Config) error {
case p.InstanceAge.Value() < 0:
return errors.New("provisioner instanceAge cannot be negative")
}
// Initialize config
p.assertConfig()
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize key store
p.keyStore, err = newKeyStore(p.config.CertsURL)
if err != nil {
return err
if p.keyStore, err = newKeyStore(p.config.CertsURL); err != nil {
return
}
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
return nil
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
return
}
// AuthorizeSign validates the given token and returns the sign options that
@ -266,22 +262,20 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
}
return append(so,
p,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators
defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
), nil
}
// AuthorizeRenew returns an error if the renewal is disabled.
func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner '%s'", p.GetName())
}
return nil
return p.ctl.AuthorizeRenew(ctx, cert)
}
// assertConfig initializes the config if it has not been initialized.
@ -328,7 +322,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
}
// validate audiences with the defaults
if !matchesAudience(claims.Audience, p.audiences.Sign) {
if !matchesAudience(claims.Audience, p.ctl.Audiences.Sign) {
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")
}
@ -383,7 +377,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName())
}
claims, err := p.authorizeToken(token)
@ -431,11 +425,11 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate user SignSSHOptions.
sshCertOptionsValidator(defaults),
// Set the validity bounds if not set.
&sshDefaultDuration{p.claimer},
&sshDefaultDuration{p.ctl.Claimer},
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertValidityValidator{p.claimer},
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
), nil

@ -8,6 +8,7 @@ import (
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"errors"
"fmt"
"net/http"
"net/http/httptest"
@ -16,10 +17,10 @@ import (
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
)
func TestGCP_Getters(t *testing.T) {
@ -390,8 +391,8 @@ func TestGCP_authorizeToken(t *testing.T) {
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -515,9 +516,9 @@ func TestGCP_AuthorizeSign(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{t1}, 5, http.StatusOK, false},
{"ok", p2, args{t2}, 10, http.StatusOK, false},
{"ok", p3, args{t3}, 5, http.StatusOK, false},
{"ok", p1, args{t1}, 6, http.StatusOK, false},
{"ok", p2, args{t2}, 11, http.StatusOK, false},
{"ok", p3, args{t3}, 6, http.StatusOK, false},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
{"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true},
@ -540,27 +541,28 @@ func TestGCP_AuthorizeSign(t *testing.T) {
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return
case err != nil:
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
default:
assert.Len(t, tt.wantLen, got)
for _, o := range got {
switch v := o.(type) {
case *GCP:
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeGCP))
assert.Equals(t, v.Type, TypeGCP)
assert.Equals(t, v.Name, tt.gcp.GetName())
assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0])
assert.Len(t, 4, v.KeyValuePairs)
case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.gcp.claimer.DefaultTLSCertDuration())
assert.Equals(t, time.Duration(v), tt.gcp.ctl.Claimer.DefaultTLSCertDuration())
case commonNameSliceValidator:
assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
case defaultPublicKeyValidator:
case *validityValidator:
assert.Equals(t, v.min, tt.gcp.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.gcp.claimer.MaxTLSCertDuration())
assert.Equals(t, v.min, tt.gcp.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.gcp.ctl.Claimer.MaxTLSCertDuration())
case ipAddressesValidator:
assert.Equals(t, v, nil)
case emailAddressesValidator:
@ -570,7 +572,7 @@ func TestGCP_AuthorizeSign(t *testing.T) {
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
}
}
@ -595,7 +597,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
// disable sshCA
disable := false
p3.Claims = &Claims{EnableSSHCA: &disable}
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
t1, err := generateGCPToken(p1.ServiceAccounts[0],
@ -622,7 +624,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SignSSHOptions{
CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
@ -677,8 +679,8 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
@ -698,6 +700,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
}
func TestGCP_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateGCP()
assert.FatalError(t, err)
p2, err := generateGCP()
@ -706,7 +709,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
@ -719,15 +722,21 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{nil}, http.StatusOK, false},
{"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true},
{"ok", p1, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renewal-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}

@ -35,8 +35,7 @@ type JWK struct {
EncryptedKey string `json:"encryptedKey,omitempty"`
Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"`
claimer *Claimer
audiences Audiences
ctl *Controller
}
// GetID returns the provisioner unique identifier. The name and credential id
@ -98,13 +97,8 @@ func (p *JWK) Init(config Config) (err error) {
return errors.New("provisioner key cannot be empty")
}
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
p.audiences = config.Audiences
return err
p.ctl, err = NewController(p, p.Claims, config)
return
}
// authorizeToken performs common jwt authorization actions and returns the
@ -146,13 +140,13 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
// AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property.
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke)
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
}
// AuthorizeSign validates the given token.
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign)
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
}
@ -176,15 +170,16 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
}
return []SignOption{
p,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators
commonNameValidator(claims.Subject),
defaultPublicKeyValidator{},
defaultSANsValidator(claims.SANs),
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
}, nil
}
@ -193,18 +188,15 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals.
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName())
}
return nil
return p.ctl.AuthorizeRenew(ctx, cert)
}
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName())
}
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign")
}
@ -261,11 +253,11 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
return append(signOptions,
// Set the validity bounds if not set.
&sshDefaultDuration{p.claimer},
&sshDefaultDuration{p.ctl.Claimer},
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertValidityValidator{p.claimer},
&sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{},
), nil
@ -273,6 +265,6 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.SSHRevoke)
_, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
}

@ -6,15 +6,17 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"errors"
"fmt"
"net/http"
"strings"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
)
func TestJWK_Getters(t *testing.T) {
@ -76,13 +78,13 @@ func TestJWK_Init(t *testing.T) {
},
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences, Claims: &Claims{DefaultTLSDur: &Duration{0}}},
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, Claims: &Claims{DefaultTLSDur: &Duration{0}}},
err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
}
},
"ok": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences},
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
}
},
}
@ -183,8 +185,8 @@ func TestJWK_authorizeToken(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
@ -223,8 +225,8 @@ func TestJWK_AuthorizeRevoke(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil {
if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
@ -288,34 +290,35 @@ func TestJWK_AuthorizeSign(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod)
if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil {
if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
} else {
if assert.NotNil(t, got) {
assert.Len(t, 7, got)
assert.Len(t, 8, got)
for _, o := range got {
switch v := o.(type) {
case *JWK:
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeJWK))
assert.Equals(t, v.Type, TypeJWK)
assert.Equals(t, v.Name, tt.prov.GetName())
assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID)
assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration())
case commonNameValidator:
assert.Equals(t, string(v), "subject")
case defaultPublicKeyValidator:
case *validityValidator:
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case defaultSANsValidator:
assert.Equals(t, []string(v), tt.sans)
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
}
}
@ -325,6 +328,7 @@ func TestJWK_AuthorizeSign(t *testing.T) {
}
func TestJWK_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateJWK()
assert.FatalError(t, err)
p2, err := generateJWK()
@ -333,7 +337,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
@ -346,16 +350,22 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{nil}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
{"ok", p1, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})
@ -373,7 +383,7 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
// disable sshCA
disable := false
p2.Claims = &Claims{EnableSSHCA: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
@ -402,8 +412,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{
CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
@ -448,8 +458,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
@ -485,8 +495,8 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
signer, err := generateJSONWebKey()
assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{
CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
@ -613,8 +623,8 @@ func TestJWK_AuthorizeSSHRevoke(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}

@ -42,16 +42,15 @@ type k8sSAPayload struct {
// entity trusted to make signature requests.
type K8sSA struct {
*base
ID string `json:"-"`
Type string `json:"type"`
Name string `json:"name"`
PubKeys []byte `json:"publicKeys,omitempty"`
Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"`
claimer *Claimer
audiences Audiences
ID string `json:"-"`
Type string `json:"type"`
Name string `json:"name"`
PubKeys []byte `json:"publicKeys,omitempty"`
Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"`
//kauthn kauthn.AuthenticationV1Interface
pubKeys []interface{}
ctl *Controller
}
// GetID returns the provisioner unique identifier. The name and credential id
@ -138,13 +137,8 @@ func (p *K8sSA) Init(config Config) (err error) {
p.kauthn = k8s.AuthenticationV1()
*/
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
p.audiences = config.Audiences
return err
p.ctl, err = NewController(p, p.Claims, config)
return
}
// authorizeToken performs common jwt authorization actions and returns the
@ -211,13 +205,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
// AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property.
func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke)
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke")
}
// AuthorizeSign validates the given token.
func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign)
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign")
}
@ -237,30 +231,28 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
}
return []SignOption{
p,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeK8sSA, p.Name, ""),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators
defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
}, nil
}
// AuthorizeRenew returns an error if the renewal is disabled.
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName())
}
return nil
return p.ctl.AuthorizeRenew(ctx, cert)
}
// AuthorizeSSHSign validates an request for an SSH certificate.
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName())
}
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
}
@ -282,11 +274,11 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Require type, key-id and principals in the SignSSHOptions.
&sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true},
// Set the validity bounds if not set.
&sshDefaultDuration{p.claimer},
&sshDefaultDuration{p.ctl.Claimer},
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertValidityValidator{p.claimer},
&sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{},
), nil

@ -3,14 +3,16 @@ package provisioner
import (
"context"
"crypto/x509"
"errors"
"fmt"
"net/http"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
)
func TestK8sSA_Getters(t *testing.T) {
@ -116,8 +118,8 @@ func TestK8sSA_authorizeToken(t *testing.T) {
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -165,8 +167,8 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -179,6 +181,7 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) {
}
func TestK8sSA_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
type test struct {
p *K8sSA
cert *x509.Certificate
@ -192,21 +195,27 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) {
// disable renewal
disable := true
p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
return test{
p: p,
cert: &x509.Certificate{},
p: p,
cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
code: http.StatusUnauthorized,
err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()),
err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
}
},
"ok": func(t *testing.T) test {
p, err := generateK8sSA(nil)
assert.FatalError(t, err)
return test{
p: p,
cert: &x509.Certificate{},
p: p,
cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
}
},
}
@ -214,8 +223,8 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -263,8 +272,8 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -274,24 +283,25 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
tot := 0
for _, o := range opts {
switch v := o.(type) {
case *K8sSA:
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeK8sSA))
assert.Equals(t, v.Type, TypeK8sSA)
assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration())
assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator:
case *validityValidator:
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
tot++
}
assert.Equals(t, tot, 5)
assert.Equals(t, tot, 6)
}
}
}
@ -313,13 +323,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
// disable sshCA
disable := false
p.Claims = &Claims{EnableSSHCA: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()),
err: fmt.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()),
}
},
"fail/invalid-token": func(t *testing.T) test {
@ -350,8 +360,8 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -365,13 +375,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
case *sshCertOptionsRequireValidator:
assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true})
case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer)
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshDefaultPublicKeyValidator:
case *sshCertDefaultValidator:
case *sshDefaultDuration:
assert.Equals(t, v.Claimer, tc.p.claimer)
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
tot++
}

@ -34,19 +34,18 @@ const (
// https://signal.org/docs/specifications/xeddsa/#xeddsa and implemented by
// go.step.sm/crypto/x25519.
type Nebula struct {
ID string `json:"-"`
Type string `json:"type"`
Name string `json:"name"`
Roots []byte `json:"roots"`
Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"`
claimer *Claimer
caPool *nebula.NebulaCAPool
audiences Audiences
ID string `json:"-"`
Type string `json:"type"`
Name string `json:"name"`
Roots []byte `json:"roots"`
Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"`
caPool *nebula.NebulaCAPool
ctl *Controller
}
// Init verifies and initializes the Nebula provisioner.
func (p *Nebula) Init(config Config) error {
func (p *Nebula) Init(config Config) (err error) {
switch {
case p.Type == "":
return errors.New("provisioner type cannot be empty")
@ -56,19 +55,14 @@ func (p *Nebula) Init(config Config) error {
return errors.New("provisioner root(s) cannot be empty")
}
var err error
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots)
if err != nil {
return errs.InternalServer("failed to create ca pool: %v", err)
}
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
return nil
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
return
}
// GetID returns the provisioner id.
@ -120,7 +114,7 @@ func (p *Nebula) GetEncryptedKey() (kid, key string, ok bool) {
// AuthorizeSign returns the list of SignOption for a Sign request.
func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
crt, claims, err := p.authorizeToken(token, p.audiences.Sign)
crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil {
return nil, err
}
@ -139,8 +133,9 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
data.SetToken(v)
}
// The Nebula certificate will be available using the template variable Crt.
// For example {{ .Crt.Details.Groups }} can be used to get all the groups.
// The Nebula certificate will be available using the template variable
// AuthorizationCrt. For example {{ .AuthorizationCrt.Details.Groups }} can
// be used to get all the groups.
data.SetAuthorizationCertificate(crt)
templateOptions, err := TemplateOptions(p.Options, data)
@ -149,11 +144,12 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
}
return []SignOption{
p,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeNebula, p.Name, ""),
profileLimitDuration{
def: p.claimer.DefaultTLSCertDuration(),
def: p.ctl.Claimer.DefaultTLSCertDuration(),
notBefore: crt.Details.NotBefore,
notAfter: crt.Details.NotAfter,
},
@ -164,18 +160,18 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
IPs: crt.Details.Ips,
},
defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
}, nil
}
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
// Currently the Nebula provisioner only grants host SSH certificates.
func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
}
crt, claims, err := p.authorizeToken(token, p.audiences.SSHSign)
crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil {
return nil, err
}
@ -253,11 +249,11 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
return append(signOptions,
templateOptions,
// Checks the validity bounds, and set the validity if has not been set.
&sshLimitDuration{p.claimer, crt.Details.NotAfter},
&sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter},
// Validate public key.
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertValidityValidator{p.claimer},
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
), nil
@ -265,23 +261,20 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
// AuthorizeRenew returns an error if the renewal is disabled.
func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized("renew is disabled for nebula provisioner '%s'", p.GetName())
}
return nil
return p.ctl.AuthorizeRenew(ctx, crt)
}
// AuthorizeRevoke returns an error if the token is not valid.
func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error {
return p.validateToken(token, p.audiences.Revoke)
return p.validateToken(token, p.ctl.Audiences.Revoke)
}
// AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid.
func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error {
if !p.claimer.IsSSHCAEnabled() {
if !p.ctl.Claimer.IsSSHCAEnabled() {
return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
}
if _, _, err := p.authorizeToken(token, p.audiences.SSHRevoke); err != nil {
if _, _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke); err != nil {
return err
}
return nil

@ -327,7 +327,7 @@ func TestNebula_GetIDForToken(t *testing.T) {
func TestNebula_GetTokenID(t *testing.T) {
p, ca, signer := mustNebulaProvisioner(t)
c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer)
t1 := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv)
t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv)
_, claims, err := parseToken(t1)
if err != nil {
t.Fatal(err)
@ -441,8 +441,8 @@ func TestNebula_AuthorizeSign(t *testing.T) {
ctx := context.TODO()
p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv)
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), nil, crt, priv)
ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv)
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), nil, crt, priv)
pBadOptions, _, _ := mustNebulaProvisioner(t)
pBadOptions.caPool = p.caPool
@ -483,20 +483,20 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) {
// Ok provisioner
p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
CertType: "host",
KeyID: "test.lan",
Principals: []string{"test.lan", "10.1.0.1"},
}, crt, priv)
okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), nil, crt, priv)
okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{
okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), nil, crt, priv)
okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)),
ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)),
}, crt, priv)
failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{
failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
CertType: "user",
}, crt, priv)
failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{
failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
CertType: "host",
KeyID: "test.lan",
Principals: []string{"test.lan", "10.1.0.1", "foo.bar"},
@ -549,6 +549,8 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) {
func TestNebula_AuthorizeRenew(t *testing.T) {
ctx := context.TODO()
now := time.Now().Truncate(time.Second)
// Ok provisioner
p, _, _ := mustNebulaProvisioner(t)
@ -567,8 +569,14 @@ func TestNebula_AuthorizeRenew(t *testing.T) {
args args
wantErr bool
}{
{"ok", p, args{ctx, &x509.Certificate{}}, false},
{"fail disabled", pDisabled, args{ctx, &x509.Certificate{}}, true},
{"ok", p, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"fail disabled", pDisabled, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -584,12 +592,12 @@ func TestNebula_AuthorizeRevoke(t *testing.T) {
// Ok provisioner
p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv)
ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv)
// Fail different CA
nc, signer := mustNebulaCA(t)
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
failToken := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv)
failToken := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv)
type args struct {
ctx context.Context
@ -618,12 +626,12 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
// Ok provisioner
p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv)
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv)
// Fail different CA
nc, signer := mustNebulaCA(t)
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv)
failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv)
// Provisioner with SSH disabled
var bFalse bool
@ -657,7 +665,7 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
func TestNebula_AuthorizeSSHRenew(t *testing.T) {
p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRenew[0], now(), nil, crt, priv)
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRenew[0], now(), nil, crt, priv)
type args struct {
ctx context.Context
@ -689,7 +697,7 @@ func TestNebula_AuthorizeSSHRenew(t *testing.T) {
func TestNebula_AuthorizeSSHRekey(t *testing.T) {
p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRekey[0], now(), nil, crt, priv)
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRekey[0], now(), nil, crt, priv)
type args struct {
ctx context.Context
@ -726,20 +734,20 @@ func TestNebula_authorizeToken(t *testing.T) {
t1 := now()
p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, nil, crt, priv)
okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, &SignSSHOptions{
ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, nil, crt, priv)
okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, &SignSSHOptions{
CertType: "host",
KeyID: "test.lan",
Principals: []string{"test.lan"},
}, crt, priv)
okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, nil, crt, priv)
okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, nil, crt, priv)
// Token with errors
failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv)
failIssuer := mustNebulaToken(t, "test.lan", "foo", p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv)
failIssuer := mustNebulaToken(t, "test.lan", "foo", p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv)
failSubject := mustNebulaToken(t, "", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
failSubject := mustNebulaToken(t, "", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
// Not a nebula token
jwk, err := generateJSONWebKey()
@ -761,7 +769,7 @@ func TestNebula_authorizeToken(t *testing.T) {
IssuedAt: jose.NewNumericDate(t1),
NotBefore: jose.NewNumericDate(t1),
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
Audience: []string{p.audiences.Sign[0]},
Audience: []string{p.ctl.Audiences.Sign[0]},
}
sshClaims := jose.Claims{
ID: "[REPLACEME]",
@ -770,7 +778,7 @@ func TestNebula_authorizeToken(t *testing.T) {
IssuedAt: jose.NewNumericDate(t1),
NotBefore: jose.NewNumericDate(t1),
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
Audience: []string{p.audiences.SSHSign[0]},
Audience: []string{p.ctl.Audiences.SSHSign[0]},
}
type args struct {
@ -785,14 +793,14 @@ func TestNebula_authorizeToken(t *testing.T) {
want1 *jwtPayload
wantErr bool
}{
{"ok x509", p, args{ok, p.audiences.Sign}, crt, &jwtPayload{
{"ok x509", p, args{ok, p.ctl.Audiences.Sign}, crt, &jwtPayload{
Claims: x509Claims,
SANs: []string{"10.1.0.1"},
}, false},
{"ok x509 no sans", p, args{okNoSANs, p.audiences.Sign}, crt, &jwtPayload{
{"ok x509 no sans", p, args{okNoSANs, p.ctl.Audiences.Sign}, crt, &jwtPayload{
Claims: x509Claims,
}, false},
{"ok ssh", p, args{okSSH, p.audiences.SSHSign}, crt, &jwtPayload{
{"ok ssh", p, args{okSSH, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{
Claims: sshClaims,
Step: &stepPayload{
SSH: &SignSSHOptions{
@ -802,16 +810,16 @@ func TestNebula_authorizeToken(t *testing.T) {
},
},
}, false},
{"ok ssh no principals", p, args{okSSHNoOptions, p.audiences.SSHSign}, crt, &jwtPayload{
{"ok ssh no principals", p, args{okSSHNoOptions, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{
Claims: sshClaims,
}, false},
{"fail parse", p, args{"bad.token", p.audiences.Sign}, nil, nil, true},
{"fail header", p, args{simpleToken, p.audiences.Sign}, nil, nil, true},
{"fail verify", p2, args{ok, p.audiences.Sign}, nil, nil, true},
{"fail claims nbf", p, args{failNotBefore, p.audiences.Sign}, nil, nil, true},
{"fail claims iss", p, args{failIssuer, p.audiences.Sign}, nil, nil, true},
{"fail claims aud", p, args{failAudience, p.audiences.Sign}, nil, nil, true},
{"fail claims sub", p, args{failSubject, p.audiences.Sign}, nil, nil, true},
{"fail parse", p, args{"bad.token", p.ctl.Audiences.Sign}, nil, nil, true},
{"fail header", p, args{simpleToken, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail verify", p2, args{ok, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims nbf", p, args{failNotBefore, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims iss", p, args{failIssuer, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims aud", p, args{failAudience, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims sub", p, args{failSubject, p.ctl.Audiences.Sign}, nil, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

@ -38,7 +38,7 @@ func (p *noop) Init(config Config) error {
}
func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{}, nil
return []SignOption{p}, nil
}
func (p *noop) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {

@ -24,6 +24,6 @@ func Test_noop(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod)
sigOptions, err := p.AuthorizeSign(ctx, "foo")
assert.Equals(t, []SignOption{}, sigOptions)
assert.Equals(t, []SignOption{&p}, sigOptions)
assert.Equals(t, nil, err)
}

@ -92,8 +92,7 @@ type OIDC struct {
Options *Options `json:"options,omitempty"`
configuration openIDConfiguration
keyStore *keyStore
claimer *Claimer
getIdentityFunc GetIdentityFunc
ctl *Controller
}
func sanitizeEmail(email string) string {
@ -172,11 +171,6 @@ func (o *OIDC) Init(config Config) (err error) {
}
}
// Update claims with global ones
if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil {
return err
}
// Decode and validate openid-configuration endpoint
u, err := url.Parse(o.ConfigurationEndpoint)
if err != nil {
@ -201,13 +195,8 @@ func (o *OIDC) Init(config Config) (err error) {
return err
}
// Set the identity getter if it exists, otherwise use the default.
if config.GetIdentityFunc == nil {
o.getIdentityFunc = DefaultIdentityFunc
} else {
o.getIdentityFunc = config.GetIdentityFunc
}
return nil
o.ctl, err = NewController(o, o.Claims, config)
return
}
// ValidatePayload validates the given token payload.
@ -356,13 +345,14 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
}
return []SignOption{
o,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
profileDefaultDuration(o.ctl.Claimer.DefaultTLSCertDuration()),
// validators
defaultPublicKeyValidator{},
newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()),
newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()),
}, nil
}
@ -371,15 +361,12 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals.
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if o.claimer.IsDisableRenewal() {
return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner '%s'", o.GetName())
}
return nil
return o.ctl.AuthorizeRenew(ctx, cert)
}
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !o.claimer.IsSSHCAEnabled() {
if !o.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName())
}
claims, err := o.authorizeToken(token)
@ -394,7 +381,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Get the identity using either the default identityFunc or one injected
// externally. Note that the PreferredUsername might be empty.
// TBD: Would preferred_username present a safety issue here?
iden, err := o.getIdentityFunc(ctx, o, claims.Email)
iden, err := o.ctl.GetIdentity(ctx, claims.Email)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
}
@ -445,11 +432,11 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
return append(signOptions,
// Set the validity bounds if not set.
&sshDefaultDuration{o.claimer},
&sshDefaultDuration{o.ctl.Claimer},
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertValidityValidator{o.claimer},
&sshCertValidityValidator{o.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
), nil

@ -6,16 +6,17 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"errors"
"fmt"
"net/http"
"strings"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
)
func Test_openIDConfiguration_Validate(t *testing.T) {
@ -246,8 +247,8 @@ func TestOIDC_authorizeToken(t *testing.T) {
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else {
@ -317,30 +318,31 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
assert.Len(t, 5, got)
assert.Len(t, 6, got)
for _, o := range got {
switch v := o.(type) {
case *OIDC:
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeOIDC))
assert.Equals(t, v.Type, TypeOIDC)
assert.Equals(t, v.Name, tt.prov.GetName())
assert.Equals(t, v.CredentialID, tt.prov.ClientID)
assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator:
case *validityValidator:
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case emailOnlyIdentity:
assert.Equals(t, string(v), "name@smallstep.com")
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
}
}
@ -402,8 +404,8 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})
@ -411,6 +413,7 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
}
func TestOIDC_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateOIDC()
assert.FatalError(t, err)
p2, err := generateOIDC()
@ -419,7 +422,7 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
// disable renewal
disable := true
p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
type args struct {
@ -432,8 +435,14 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{nil}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
{"ok", p1, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -441,8 +450,8 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
if (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})
@ -478,7 +487,7 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
// disable sshCA
disable := false
p6.Claims = &Claims{EnableSSHCA: &disable}
p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims)
p6.ctl.Claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
// Update configuration endpoints and initialize
@ -494,10 +503,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
assert.FatalError(t, p4.Init(config))
assert.FatalError(t, p5.Init(config))
p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
p4.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
return &Identity{Usernames: []string{"max", "mariano"}}, nil
}
p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
p5.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
return nil, errors.New("force")
}
// Additional test needed for empty usernames and duplicate email and usernames
@ -527,8 +536,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{
CertType: "user", Principals: []string{"name", "name@smallstep.com"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
@ -597,8 +606,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
return
}
if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
@ -665,8 +674,8 @@ func TestOIDC_AuthorizeSSHRevoke(t *testing.T) {
if (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
}
})

@ -6,7 +6,6 @@ import (
"encoding/json"
stderrors "errors"
"net/url"
"regexp"
"strings"
"github.com/pkg/errors"
@ -47,6 +46,7 @@ var ErrAllowTokenReuse = stderrors.New("allow token reuse")
// Audiences stores all supported audiences by request type.
type Audiences struct {
Sign []string
Renew []string
Revoke []string
SSHSign []string
SSHRevoke []string
@ -57,6 +57,7 @@ type Audiences struct {
// All returns all supported audiences across all request types in one list.
func (a Audiences) All() (auds []string) {
auds = a.Sign
auds = append(auds, a.Renew...)
auds = append(auds, a.Revoke...)
auds = append(auds, a.SSHSign...)
auds = append(auds, a.SSHRevoke...)
@ -70,6 +71,7 @@ func (a Audiences) All() (auds []string) {
func (a Audiences) WithFragment(fragment string) Audiences {
ret := Audiences{
Sign: make([]string, len(a.Sign)),
Renew: make([]string, len(a.Renew)),
Revoke: make([]string, len(a.Revoke)),
SSHSign: make([]string, len(a.SSHSign)),
SSHRevoke: make([]string, len(a.SSHRevoke)),
@ -83,6 +85,13 @@ func (a Audiences) WithFragment(fragment string) Audiences {
ret.Sign[i] = s
}
}
for i, s := range a.Renew {
if u, err := url.Parse(s); err == nil {
ret.Renew[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
} else {
ret.Renew[i] = s
}
}
for i, s := range a.Revoke {
if u, err := url.Parse(s); err == nil {
ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
@ -210,6 +219,12 @@ type Config struct {
// GetIdentityFunc is a function that returns an identity that will be
// used by the provisioner to populate certificate attributes.
GetIdentityFunc GetIdentityFunc
// AuthorizeRenewFunc is a function that returns nil if a given X.509
// certificate can be renewed.
AuthorizeRenewFunc AuthorizeRenewFunc
// AuthorizeSSHRenewFunc is a function that returns nil if a given SSH
// certificate can be renewed.
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
}
type provisioner struct {
@ -278,32 +293,6 @@ func (l *List) UnmarshalJSON(data []byte) error {
return nil
}
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
// SanitizeSSHUserPrincipal grabs an email or a string with the format
// local@domain and returns a sanitized version of the local, valid to be used
// as a user name. If the email starts with a letter between a and z, the
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
func SanitizeSSHUserPrincipal(email string) string {
if i := strings.LastIndex(email, "@"); i >= 0 {
email = email[:i]
}
return strings.Map(func(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r
case r >= '0' && r <= '9':
return r
case r == '-':
return '-'
case r == '.': // drop dots
return -1
default:
return '_'
}
}, strings.ToLower(email))
}
type base struct{}
// AuthorizeSign returns an unimplemented error. Provisioners should overwrite
@ -348,66 +337,12 @@ func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certif
return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
}
// Identity is the type representing an externally supplied identity that is used
// by provisioners to populate certificate fields.
type Identity struct {
Usernames []string `json:"usernames"`
Permissions `json:"permissions"`
}
// Permissions defines extra extensions and critical options to grant to an SSH certificate.
type Permissions struct {
Extensions map[string]string `json:"extensions"`
CriticalOptions map[string]string `json:"criticalOptions"`
}
// GetIdentityFunc is a function that returns an identity.
type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
// DefaultIdentityFunc return a default identity depending on the provisioner
// type. For OIDC email is always present and the usernames might
// contain empty strings.
func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
switch k := p.(type) {
case *OIDC:
// OIDC principals would be:
// ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled
// 2. Sanitized local.
// 3. Raw local (if different).
// 4. Email address.
name := SanitizeSSHUserPrincipal(email)
if !sshUserRegex.MatchString(name) {
return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email)
}
usernames := []string{name}
if i := strings.LastIndex(email, "@"); i >= 0 {
usernames = append(usernames, email[:i])
}
usernames = append(usernames, email)
return &Identity{
Usernames: SanitizeStringSlices(usernames),
}, nil
default:
return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k)
}
}
// SanitizeStringSlices removes duplicated an empty strings.
func SanitizeStringSlices(original []string) []string {
output := []string{}
seen := make(map[string]struct{})
for _, entry := range original {
if entry == "" {
continue
}
if _, value := seen[entry]; !value {
seen[entry] = struct{}{}
output = append(output, entry)
}
}
return output
}
// MockProvisioner for testing
type MockProvisioner struct {
Mret1, Mret2, Mret3 interface{}

@ -2,13 +2,14 @@ package provisioner
import (
"context"
"errors"
"net/http"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
)
func TestType_String(t *testing.T) {
@ -240,8 +241,8 @@ func TestUnimplementedMethods(t *testing.T) {
default:
t.Errorf("unexpected method %s", tt.method)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized)
assert.Equals(t, err.Error(), msg)
})

@ -11,28 +11,30 @@ import (
// SCEP provisioning flow
type SCEP struct {
*base
ID string `json:"-"`
Type string `json:"type"`
Name string `json:"name"`
ID string `json:"-"`
Type string `json:"type"`
Name string `json:"name"`
ForceCN bool `json:"forceCN,omitempty"`
ChallengePassword string `json:"challenge,omitempty"`
Capabilities []string `json:"capabilities,omitempty"`
// IncludeRoot makes the provisioner return the CA root in addition to the
// intermediate in the GetCACerts response
IncludeRoot bool `json:"includeRoot,omitempty"`
// MinimumPublicKeyLength is the minimum length for public keys in CSRs
MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"`
// Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7
// at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63
// Defaults to 0, being DES-CBC
EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"`
Options *Options `json:"options,omitempty"`
Claims *Claims `json:"claims,omitempty"`
claimer *Claimer
EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"`
Options *Options `json:"options,omitempty"`
Claims *Claims `json:"claims,omitempty"`
secretChallengePassword string
encryptionAlgorithm int
ctl *Controller
}
// GetID returns the provisioner unique identifier.
@ -77,7 +79,7 @@ func (s *SCEP) GetOptions() *Options {
// DefaultTLSCertDuration returns the default TLS cert duration enforced by
// the provisioner.
func (s *SCEP) DefaultTLSCertDuration() time.Duration {
return s.claimer.DefaultTLSCertDuration()
return s.ctl.Claimer.DefaultTLSCertDuration()
}
// Init initializes and validates the fields of a SCEP type.
@ -90,11 +92,6 @@ func (s *SCEP) Init(config Config) (err error) {
return errors.New("provisioner name cannot be empty")
}
// Update claims with global ones
if s.claimer, err = NewClaimer(s.Claims, config.Claims); err != nil {
return err
}
// Mask the actual challenge value, so it won't be marshaled
s.secretChallengePassword = s.ChallengePassword
s.ChallengePassword = "*** redacted ***"
@ -115,7 +112,8 @@ func (s *SCEP) Init(config Config) (err error) {
// TODO: add other, SCEP specific, options?
return err
s.ctl, err = NewController(s, s.Claims, config)
return
}
// AuthorizeSign does not do any verification, because all verification is handled
@ -123,13 +121,14 @@ func (s *SCEP) Init(config Config) (err error) {
// on the resulting certificate.
func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{
s,
// modifiers / withOptions
newProvisionerExtensionOption(TypeSCEP, s.Name, ""),
newForceCNOption(s.ForceCN),
profileDefaultDuration(s.claimer.DefaultTLSCertDuration()),
profileDefaultDuration(s.ctl.Claimer.DefaultTLSCertDuration()),
// validators
newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength),
newValidityValidator(s.claimer.MinTLSCertDuration(), s.claimer.MaxTLSCertDuration()),
newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()),
}, nil
}

@ -6,7 +6,6 @@ import (
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/json"
"net"
"net/http"
@ -14,7 +13,6 @@ import (
"reflect"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util"
@ -404,17 +402,12 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
return nil
}
var (
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
)
type stepProvisionerASN1 struct {
Type int
Name []byte
CredentialID []byte
KeyValuePairs []string `asn1:"optional,omitempty"`
}
// type stepProvisionerASN1 struct {
// Type int
// Name []byte
// CredentialID []byte
// KeyValuePairs []string `asn1:"optional,omitempty"`
// }
type forceCNOption struct {
ForceCN bool
@ -441,23 +434,22 @@ func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error {
}
type provisionerExtensionOption struct {
Type int
Name string
CredentialID string
KeyValuePairs []string
Extension
}
func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption {
return &provisionerExtensionOption{
Type: int(typ),
Name: name,
CredentialID: credentialID,
KeyValuePairs: keyValuePairs,
Extension: Extension{
Type: typ,
Name: name,
CredentialID: credentialID,
KeyValuePairs: keyValuePairs,
},
}
}
func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error {
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...)
ext, err := o.ToExtension()
if err != nil {
return errs.NewError(http.StatusInternalServerError, err, "error creating certificate")
}
@ -471,20 +463,3 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption
cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...)
return nil
}
func createProvisionerExtension(typ int, name, credentialID string, keyValuePairs ...string) (pkix.Extension, error) {
b, err := asn1.Marshal(stepProvisionerASN1{
Type: typ,
Name: []byte(name),
CredentialID: []byte(credentialID),
KeyValuePairs: keyValuePairs,
})
if err != nil {
return pkix.Extension{}, errors.Wrap(err, "error marshaling provisioner extension")
}
return pkix.Extension{
Id: stepOIDProvisioner,
Critical: false,
Value: b,
}, nil
}

@ -636,18 +636,18 @@ func Test_newProvisionerExtension_Option(t *testing.T) {
valid: func(cert *x509.Certificate) {
if assert.Len(t, 1, cert.ExtraExtensions) {
ext := cert.ExtraExtensions[0]
assert.Equals(t, ext.Id, stepOIDProvisioner)
assert.Equals(t, ext.Id, StepOIDProvisioner)
}
},
}
},
"ok/prepend": func() test {
return test{
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: stepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}},
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: StepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}},
valid: func(cert *x509.Certificate) {
if assert.Len(t, 3, cert.ExtraExtensions) {
ext := cert.ExtraExtensions[0]
assert.Equals(t, ext.Id, stepOIDProvisioner)
assert.Equals(t, ext.Id, StepOIDProvisioner)
assert.False(t, ext.Critical)
}
},

@ -685,7 +685,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) {
func Test_sshCertValidityValidator(t *testing.T) {
p, err := generateX5C(nil)
assert.FatalError(t, err)
v := sshCertValidityValidator{p.claimer}
v := sshCertValidityValidator{p.ctl.Claimer}
n := now()
tests := []struct {
name string
@ -806,7 +806,7 @@ func Test_sshValidityModifier(t *testing.T) {
tests := map[string]func() test{
"fail/type-not-set": func() test {
return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)},
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)},
cert: &ssh.Certificate{
ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(8 * time.Hour).Unix()),
@ -816,7 +816,7 @@ func Test_sshValidityModifier(t *testing.T) {
},
"fail/type-not-recognized": func() test {
return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)},
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)},
cert: &ssh.Certificate{
CertType: 4,
ValidAfter: uint64(n.Unix()),
@ -827,7 +827,7 @@ func Test_sshValidityModifier(t *testing.T) {
},
"fail/requested-validAfter-after-limit": func() test {
return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)},
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)},
cert: &ssh.Certificate{
CertType: 1,
ValidAfter: uint64(n.Add(2 * time.Hour).Unix()),
@ -838,7 +838,7 @@ func Test_sshValidityModifier(t *testing.T) {
},
"fail/requested-validBefore-after-limit": func() test {
return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)},
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)},
cert: &ssh.Certificate{
CertType: 1,
ValidAfter: uint64(n.Unix()),
@ -850,7 +850,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/no-limit": func() test {
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
return test{
svm: &sshLimitDuration{Claimer: p.claimer},
svm: &sshLimitDuration{Claimer: p.ctl.Claimer},
cert: &ssh.Certificate{
CertType: 1,
},
@ -863,7 +863,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/defaults": func() test {
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
return test{
svm: &sshLimitDuration{Claimer: p.claimer},
svm: &sshLimitDuration{Claimer: p.ctl.Claimer},
cert: &ssh.Certificate{
CertType: 1,
},
@ -876,7 +876,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/valid-requested-validBefore": func() test {
va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix())
return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)},
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)},
cert: &ssh.Certificate{
CertType: 1,
ValidAfter: va,
@ -891,7 +891,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/empty-requested-validBefore-limit-after-default": func() test {
va := uint64(n.Unix())
return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(24 * time.Hour)},
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(24 * time.Hour)},
cert: &ssh.Certificate{
CertType: 1,
ValidAfter: va,
@ -905,7 +905,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/empty-requested-validBefore-limit-before-default": func() test {
va := uint64(n.Unix())
return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)},
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)},
cert: &ssh.Certificate{
CertType: 1,
ValidAfter: va,

@ -29,8 +29,7 @@ type SSHPOP struct {
Type string `json:"type"`
Name string `json:"name"`
Claims *Claims `json:"claims,omitempty"`
claimer *Claimer
audiences Audiences
ctl *Controller
sshPubKeys *SSHKeys
}
@ -83,7 +82,7 @@ func (p *SSHPOP) GetEncryptedKey() (string, string, bool) {
}
// Init initializes and validates the fields of a SSHPOP type.
func (p *SSHPOP) Init(config Config) error {
func (p *SSHPOP) Init(config Config) (err error) {
switch {
case p.Type == "":
return errors.New("provisioner type cannot be empty")
@ -93,15 +92,11 @@ func (p *SSHPOP) Init(config Config) error {
return errors.New("provisioner public SSH validation keys cannot be empty")
}
// Update claims with global ones
var err error
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.sshPubKeys = config.SSHKeys
return nil
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
return
}
// authorizeToken performs common jwt authorization actions and returns the
@ -109,7 +104,7 @@ func (p *SSHPOP) Init(config Config) error {
// e.g. a Sign request will auth/validate different fields than a Revoke request.
//
// Checking for certificate revocation has been moved to the authority package.
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) {
func (p *SSHPOP) authorizeToken(token string, audiences []string, checkValidity bool) (*sshPOPPayload, error) {
sshCert, jwt, err := ExtractSSHPOPCert(token)
if err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err,
@ -117,13 +112,18 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
}
// Check validity period of the certificate.
n := time.Now()
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
}
if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
//
// Controller.AuthorizeSSHRenew will validate this on the renewal flow.
if checkValidity {
unixNow := time.Now().Unix()
if after := int64(sshCert.ValidAfter); after < 0 || unixNow < int64(sshCert.ValidAfter) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
}
if before := int64(sshCert.ValidBefore); sshCert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
}
}
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
if !ok {
return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
@ -186,7 +186,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
// AuthorizeSSHRevoke validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
claims, err := p.authorizeToken(token, p.audiences.SSHRevoke)
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke, true)
if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
}
@ -199,22 +199,20 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
// AuthorizeSSHRenew validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
claims, err := p.authorizeToken(token, p.audiences.SSHRenew)
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew, false)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
}
if claims.sshCert.CertType != ssh.HostCert {
return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
}
return claims.sshCert, nil
return claims.sshCert, p.ctl.AuthorizeSSHRenew(ctx, claims.sshCert)
}
// AuthorizeSSHRekey validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.SSHRekey)
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey, true)
if err != nil {
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
}
@ -225,11 +223,10 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
// Validate public key
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertValidityValidator{p.claimer},
&sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{},
}, nil
}
// ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate

@ -5,16 +5,19 @@ import (
"crypto"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net/http"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
)
func TestSSHPOP_Getters(t *testing.T) {
@ -38,6 +41,7 @@ func TestSSHPOP_Getters(t *testing.T) {
}
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
now := time.Now()
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
if err != nil {
return nil, nil, err
@ -46,6 +50,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
if err != nil {
return nil, nil, err
}
if cert.ValidAfter == 0 {
cert.ValidAfter = uint64(now.Unix())
}
if cert.ValidBefore == 0 {
cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
}
if err := cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err
}
@ -207,9 +217,9 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil {
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -279,8 +289,8 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -360,8 +370,8 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
tc := tt(t)
if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -442,8 +452,8 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
tc := tt(t)
if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -455,9 +465,9 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
case *sshDefaultPublicKeyValidator:
case *sshCertDefaultValidator:
case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer)
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
}
assert.Equals(t, tc.cert.Nonce, cert.Nonce)

@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDeTCCAx+gAwIBAgIRAOTItW2pYuSU+PkmLW090iUwCgYIKoZIzj0EAwIwJDEi
MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjUy
MjBaFw0yMjAzMTIyMjUzMjBaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs
aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg
U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0
ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI
KoZIzj0DAQcDQgAE/9vvOZ1Zzysnf3VeGyotMJEMZdAborB36Ah5QL/3yQNMRWIc
pv9Dwx19pHw7SquVE8jIaPPJSjaeWnfMPDYDxaOCAbcwggGzMA4GA1UdDwEB/wQE
AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw
ADAdBgNVHQ4EFgQUkJUg6AsqWlqTZt6BHidRMwh1vKYwHwYDVR0jBBgwFoAUDpTg
d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB
hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu
Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh
NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA
ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G
A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2
LXNlcnZlci1nMy5jcmwwFwYMKwYBBAGCpGTGKEABBAdmb29vYmFyMAoGCCqGSM49
BAMCA0gAMEUCIQCWYqOuk4bLkVVeHvo3P8TlJJ3fw6ijDDLstvdrQqAl5wIgEjSY
wVcR649Oc8PJGh/43Kpx0+4OTYPQrD/JqphVF7g=
-----END CERTIFICATE-----

@ -0,0 +1,22 @@
-----BEGIN CERTIFICATE-----
MIIDujCCA2GgAwIBAgIRAM5celDKTTqAGycljO7FZdEwCgYIKoZIzj0EAwIwJDEi
MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjQx
MDRaFw0yMjAzMTIyMjQyMDRaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs
aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg
U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0
ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI
KoZIzj0DAQcDQgAEkXffZYlSJRMxJrZHmUpEMC4jQYCkF86mLJY0iLZ8k00N/xF0
4rAGwzTU/l9tfRpNl+z/XfMMWPXS0Q8NU/o4S6OCAfkwggH1MA4GA1UdDwEB/wQE
AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw
ADAdBgNVHQ4EFgQUL3sSlYW8Tf2l2P+gFTdn5wsUjfgwHwYDVR0jBBgwFoAUDpTg
d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB
hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu
Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh
NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA
ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G
A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2
LXNlcnZlci1nMy5jcmwwWQYMKwYBBAGCpGTGKEABBEkwRwIBAQQVbWFyaWFub0Bz
bWFsbHN0ZXAuY29tBCtudmduUjh3U3pwVWxydF90QzNtdnJod2hCeDlZN1QxV0xf
SmpjRlZXWUJRMAoGCCqGSM49BAMCA0cAMEQCIE6umrhSbeQWWVK5cWBvXj5c0cGB
bUF0rNw/dsaCaWcwAiAKSkmjhsC63DVPXPCNUki90YgVovO69foO1ZaB43lx5w==
-----END CERTIFICATE-----

@ -24,20 +24,22 @@ import (
)
var (
defaultDisableRenewal = false
defaultEnableSSHCA = true
globalProvisionerClaims = Claims{
MinTLSDur: &Duration{5 * time.Minute},
MaxTLSDur: &Duration{24 * time.Hour},
DefaultTLSDur: &Duration{24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &defaultEnableSSHCA,
defaultDisableRenewal = false
defaultAllowRenewAfterExpiry = false
defaultEnableSSHCA = true
globalProvisionerClaims = Claims{
MinTLSDur: &Duration{5 * time.Minute},
MaxTLSDur: &Duration{24 * time.Hour},
DefaultTLSDur: &Duration{24 * time.Hour},
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &defaultEnableSSHCA,
DisableRenewal: &defaultDisableRenewal,
AllowRenewAfterExpiry: &defaultAllowRenewAfterExpiry,
}
testAudiences = Audiences{
Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"},
@ -172,19 +174,18 @@ func generateJWK() (*JWK, error) {
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
return &JWK{
p := &JWK{
Name: name,
Type: "JWK",
Key: &public,
EncryptedKey: encrypted,
Claims: &globalProvisionerClaims,
audiences: testAudiences,
claimer: claimer,
}, nil
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
}
func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
@ -205,23 +206,21 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
pubKeys := []interface{}{fooPub, barPub}
if inputPubKey != nil {
pubKeys = append(pubKeys, inputPubKey)
}
return &K8sSA{
Name: K8sSAName,
Type: "K8sSA",
Claims: &globalProvisionerClaims,
audiences: testAudiences,
claimer: claimer,
pubKeys: pubKeys,
}, nil
p := &K8sSA{
Name: K8sSAName,
Type: "K8sSA",
Claims: &globalProvisionerClaims,
pubKeys: pubKeys,
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
}
func generateSSHPOP() (*SSHPOP, error) {
@ -229,11 +228,6 @@ func generateSSHPOP() (*SSHPOP, error) {
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub")
if err != nil {
return nil, err
@ -251,17 +245,19 @@ func generateSSHPOP() (*SSHPOP, error) {
return nil, err
}
return &SSHPOP{
Name: name,
Type: "SSHPOP",
Claims: &globalProvisionerClaims,
audiences: testAudiences,
claimer: claimer,
p := &SSHPOP{
Name: name,
Type: "SSHPOP",
Claims: &globalProvisionerClaims,
sshPubKeys: &SSHKeys{
UserKeys: []ssh.PublicKey{userKey},
HostKeys: []ssh.PublicKey{hostKey},
},
}, nil
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
}
func generateX5C(root []byte) (*X5C, error) {
@ -283,11 +279,6 @@ M46l92gdOozT
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
rootPool := x509.NewCertPool()
var (
@ -305,15 +296,17 @@ M46l92gdOozT
}
rootPool.AddCert(cert)
}
return &X5C{
Name: name,
Type: "X5C",
Roots: root,
Claims: &globalProvisionerClaims,
audiences: testAudiences,
claimer: claimer,
rootPool: rootPool,
}, nil
p := &X5C{
Name: name,
Type: "X5C",
Roots: root,
Claims: &globalProvisionerClaims,
rootPool: rootPool,
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
}
func generateOIDC() (*OIDC, error) {
@ -333,11 +326,7 @@ func generateOIDC() (*OIDC, error) {
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
return &OIDC{
p := &OIDC{
Name: name,
Type: "OIDC",
ClientID: clientID,
@ -351,8 +340,11 @@ func generateOIDC() (*OIDC, error) {
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour),
},
claimer: claimer,
}, nil
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
}
func generateGCP() (*GCP, error) {
@ -368,23 +360,21 @@ func generateGCP() (*GCP, error) {
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
return &GCP{
p := &GCP{
Type: "GCP",
Name: name,
ServiceAccounts: []string{serviceAccount},
Claims: &globalProvisionerClaims,
claimer: claimer,
config: newGCPConfig(),
keyStore: &keyStore{
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour),
},
audiences: testAudiences.WithFragment("gcp/" + name),
}, nil
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("gcp/" + name),
})
return p, err
}
func generateAWS() (*AWS, error) {
@ -396,10 +386,6 @@ func generateAWS() (*AWS, error) {
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
block, _ := pem.Decode([]byte(awsTestCertificate))
if block == nil || block.Type != "CERTIFICATE" {
return nil, errors.New("error decoding AWS certificate")
@ -408,13 +394,12 @@ func generateAWS() (*AWS, error) {
if err != nil {
return nil, errors.Wrap(err, "error parsing AWS certificate")
}
return &AWS{
p := &AWS{
Type: "AWS",
Name: name,
Accounts: []string{accountID},
Claims: &globalProvisionerClaims,
IMDSVersions: []string{"v2", "v1"},
claimer: claimer,
config: &awsConfig{
identityURL: awsIdentityURL,
signatureURL: awsSignatureURL,
@ -423,8 +408,11 @@ func generateAWS() (*AWS, error) {
certificates: []*x509.Certificate{cert},
signatureAlgorithm: awsSignatureAlgorithm,
},
audiences: testAudiences.WithFragment("aws/" + name),
}, nil
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("aws/" + name),
})
return p, err
}
func generateAWSWithServer() (*AWS, *httptest.Server, error) {
@ -505,10 +493,6 @@ func generateAWSV1Only() (*AWS, error) {
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
block, _ := pem.Decode([]byte(awsTestCertificate))
if block == nil || block.Type != "CERTIFICATE" {
return nil, errors.New("error decoding AWS certificate")
@ -517,13 +501,12 @@ func generateAWSV1Only() (*AWS, error) {
if err != nil {
return nil, errors.Wrap(err, "error parsing AWS certificate")
}
return &AWS{
p := &AWS{
Type: "AWS",
Name: name,
Accounts: []string{accountID},
Claims: &globalProvisionerClaims,
IMDSVersions: []string{"v1"},
claimer: claimer,
config: &awsConfig{
identityURL: awsIdentityURL,
signatureURL: awsSignatureURL,
@ -532,8 +515,11 @@ func generateAWSV1Only() (*AWS, error) {
certificates: []*x509.Certificate{cert},
signatureAlgorithm: awsSignatureAlgorithm,
},
audiences: testAudiences.WithFragment("aws/" + name),
}, nil
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("aws/" + name),
})
return p, err
}
func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) {
@ -600,21 +586,16 @@ func generateAzure() (*Azure, error) {
if err != nil {
return nil, err
}
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
jwk, err := generateJSONWebKey()
if err != nil {
return nil, err
}
return &Azure{
p := &Azure{
Type: "Azure",
Name: name,
TenantID: tenantID,
Audience: azureDefaultAudience,
Claims: &globalProvisionerClaims,
claimer: claimer,
config: newAzureConfig(tenantID),
oidcConfig: openIDConfiguration{
Issuer: "https://sts.windows.net/" + tenantID + "/",
@ -624,7 +605,11 @@ func generateAzure() (*Azure, error) {
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour),
},
}, nil
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
}
func generateAzureWithServer() (*Azure, *httptest.Server, error) {
@ -671,7 +656,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, getPublic(az.keyStore.keySet))
case "/metadata/identity/oauth2/token":
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0])
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &az.keyStore.keySet.Keys[0])
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
} else {
@ -1009,7 +994,7 @@ func generateAWSToken(p *AWS, sub, iss, aud, accountID, instanceID, privateIP, r
return jose.Signed(sig).Claims(claims).CompactSerialize()
}
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, resourceName, resourceType string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
@ -1017,6 +1002,12 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
if err != nil {
return "", err
}
var xmsMirID string
if resourceType == "vm" {
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, resourceName)
} else if resourceType == "uai" {
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.ManagedIdentity/userAssignedIdentities/%s", subscriptionID, resourceGroup, resourceName)
}
claims := azurePayload{
Claims: jose.Claims{
@ -1034,7 +1025,7 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
ObjectID: "the-oid",
TenantID: tenantID,
Version: "the-version",
XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine),
XMSMirID: xmsMirID,
}
return jose.Signed(sig).Claims(claims).CompactSerialize()
}

@ -26,15 +26,14 @@ type x5cPayload struct {
// signature requests.
type X5C struct {
*base
ID string `json:"-"`
Type string `json:"type"`
Name string `json:"name"`
Roots []byte `json:"roots"`
Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"`
claimer *Claimer
audiences Audiences
rootPool *x509.CertPool
ID string `json:"-"`
Type string `json:"type"`
Name string `json:"name"`
Roots []byte `json:"roots"`
Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"`
ctl *Controller
rootPool *x509.CertPool
}
// GetID returns the provisioner unique identifier. The name and credential id
@ -86,7 +85,7 @@ func (p *X5C) GetEncryptedKey() (string, string, bool) {
}
// Init initializes and validates the fields of a X5C type.
func (p *X5C) Init(config Config) error {
func (p *X5C) Init(config Config) (err error) {
switch {
case p.Type == "":
return errors.New("provisioner type cannot be empty")
@ -101,6 +100,7 @@ func (p *X5C) Init(config Config) error {
var (
block *pem.Block
rest = p.Roots
count int
)
for rest != nil {
block, rest = pem.Decode(rest)
@ -111,22 +111,18 @@ func (p *X5C) Init(config Config) error {
if err != nil {
return errors.Wrap(err, "error parsing x509 certificate from PEM block")
}
count++
p.rootPool.AddCert(cert)
}
// Verify that at least one root was found.
if len(p.rootPool.Subjects()) == 0 {
if count == 0 {
return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName())
}
// Update claims with global ones
var err error
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
return nil
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
return
}
// authorizeToken performs common jwt authorization actions and returns the
@ -189,13 +185,13 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
// AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property.
func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke)
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke")
}
// AuthorizeSign validates the given token.
func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign)
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign")
}
@ -213,40 +209,45 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
data.SetToken(v)
}
// The X509 certificate will be available using the template variable
// AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be
// used to get all the domains.
data.SetAuthorizationCertificate(claims.chains[0][0])
templateOptions, err := TemplateOptions(p.Options, data)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
}
return []SignOption{
p,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeX5C, p.Name, ""),
profileLimitDuration{p.claimer.DefaultTLSCertDuration(),
claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter},
profileLimitDuration{
p.ctl.Claimer.DefaultTLSCertDuration(),
claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter,
},
// validators
commonNameValidator(claims.Subject),
defaultSANsValidator(claims.SANs),
defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
}, nil
}
// AuthorizeRenew returns an error if the renewal is disabled.
func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() {
return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName())
}
return nil
return p.ctl.AuthorizeRenew(ctx, cert)
}
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() {
if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName())
}
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
}
@ -287,6 +288,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
data.SetToken(v)
}
// The X509 certificate will be available using the template variable
// AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be
// used to get all the domains.
data.SetAuthorizationCertificate(claims.chains[0][0])
templateOptions, err := TemplateSSHOptions(p.Options, data)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
@ -304,11 +310,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
return append(signOptions,
// Checks the validity bounds, and set the validity if has not been set.
&sshLimitDuration{p.claimer, claims.chains[0][0].NotAfter},
&sshLimitDuration{p.ctl.Claimer, claims.chains[0][0].NotAfter},
// Validate public key.
&sshDefaultPublicKeyValidator{},
// Validate the validity period.
&sshCertValidityValidator{p.claimer},
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
), nil

@ -2,16 +2,19 @@ package provisioner
import (
"context"
"crypto/x509"
"errors"
"fmt"
"net/http"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
)
func TestX5C_Getters(t *testing.T) {
@ -69,8 +72,8 @@ func TestX5C_Init(t *testing.T) {
},
"fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{
p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo"), audiences: testAudiences},
err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"),
p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")},
err: errors.New("no x509 certificates found in roots attribute for provisioner 'foo'"),
}
},
"fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest {
@ -117,9 +120,11 @@ M46l92gdOozT
return ProvisionerValidateTest{
p: p,
extraValid: func(p *X5C) error {
// nolint:staticcheck // We don't have a different way to
// check the number of certificates in the pool.
numCerts := len(p.rootPool.Subjects())
if numCerts != 2 {
return errors.Errorf("unexpected number of certs: want 2, but got %d", numCerts)
return fmt.Errorf("unexpected number of certs: want 2, but got %d", numCerts)
}
return nil
},
@ -141,7 +146,7 @@ M46l92gdOozT
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.p.audiences, config.Audiences.WithFragment(tc.p.GetID()))
assert.Equals(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID()))
if tc.extraValid != nil {
assert.Nil(t, tc.extraValid(tc.p))
}
@ -384,8 +389,8 @@ lgsqsR63is+0YQ==
tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -455,7 +460,7 @@ func TestX5C_AuthorizeSign(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -463,19 +468,20 @@ func TestX5C_AuthorizeSign(t *testing.T) {
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
assert.Equals(t, len(opts), 7)
assert.Equals(t, len(opts), 8)
for _, o := range opts {
switch v := o.(type) {
case *X5C:
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeX5C))
assert.Equals(t, v.Type, TypeX5C)
assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs)
case profileLimitDuration:
assert.Equals(t, v.def, tc.p.claimer.DefaultTLSCertDuration())
assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration())
claims, err := tc.p.authorizeToken(tc.token, tc.p.audiences.Sign)
claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign)
assert.FatalError(t, err)
assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter)
case commonNameValidator:
@ -484,10 +490,10 @@ func TestX5C_AuthorizeSign(t *testing.T) {
case defaultSANsValidator:
assert.Equals(t, []string(v), tc.sans)
case *validityValidator:
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
}
}
@ -538,8 +544,8 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -551,6 +557,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
}
func TestX5C_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
type test struct {
p *X5C
code int
@ -563,12 +570,12 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
// disable renewal
disable := true
p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
return test{
p: p,
code: http.StatusUnauthorized,
err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()),
err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
}
},
"ok": func(t *testing.T) test {
@ -582,10 +589,13 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil {
if err := tc.p.AuthorizeRenew(context.Background(), &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -618,13 +628,13 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
// disable sshCA
enable := false
p.Claims = &Claims{EnableSSHCA: &enable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
return test{
p: p,
token: "foo",
code: http.StatusUnauthorized,
err: errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()),
err: fmt.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()),
}
},
"fail/invalid-token": func(t *testing.T) test {
@ -745,7 +755,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -774,13 +784,13 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
case sshCertDefaultsModifier:
assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert})
case *sshLimitDuration:
assert.Equals(t, v.Claimer, tc.p.claimer)
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer)
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc:
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
tot++
}

@ -87,20 +87,20 @@ func (a *Authority) LoadProvisionerByName(name string) (provisioner.Interface, e
return p, nil
}
func (a *Authority) generateProvisionerConfig(ctx context.Context) (*provisioner.Config, error) {
func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner.Config, error) {
// Merge global and configuration claims
claimer, err := provisioner.NewClaimer(a.config.AuthorityConfig.Claims, config.GlobalProvisionerClaims)
if err != nil {
return nil, err
return provisioner.Config{}, err
}
// TODO: should we also be combining the ssh federated roots here?
// If we rotate ssh roots keys, sshpop provisioner will lose ability to
// validate old SSH certificates, unless they are added as federated certs.
sshKeys, err := a.GetSSHRoots(ctx)
if err != nil {
return nil, err
return provisioner.Config{}, err
}
return &provisioner.Config{
return provisioner.Config{
Claims: claimer.Claims(),
Audiences: a.config.GetAudiences(),
DB: a.db,
@ -108,7 +108,9 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (*provisioner
UserKeys: sshKeys.UserKeys,
HostKeys: sshKeys.HostKeys,
},
GetIdentityFunc: a.getIdentityFunc,
GetIdentityFunc: a.getIdentityFunc,
AuthorizeRenewFunc: a.authorizeRenewFunc,
AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc,
}, nil
}
@ -133,9 +135,18 @@ func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisi
"provisioner with token ID %s already exists", certProv.GetIDForToken())
}
provisionerConfig, err := a.generateProvisionerConfig(ctx)
if err != nil {
return admin.WrapErrorISE(err, "error generating provisioner config")
}
if err := certProv.Init(provisionerConfig); err != nil {
return admin.WrapError(admin.ErrorBadRequestType, err, "error validating configuration for provisioner %s", prov.Name)
}
// Store to database -- this will set the ID.
if err := a.adminDB.CreateProvisioner(ctx, prov); err != nil {
return admin.WrapErrorISE(err, "error creating admin")
return admin.WrapErrorISE(err, "error creating provisioner")
}
// We need a new conversion that has the newly set ID.
@ -145,12 +156,7 @@ func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisi
"error converting to certificates provisioner from linkedca provisioner")
}
provisionerConfig, err := a.generateProvisionerConfig(ctx)
if err != nil {
return admin.WrapErrorISE(err, "error generating provisioner config")
}
if err := certProv.Init(*provisionerConfig); err != nil {
if err := certProv.Init(provisionerConfig); err != nil {
return admin.WrapErrorISE(err, "error initializing provisioner %s", prov.Name)
}
@ -179,7 +185,7 @@ func (a *Authority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisio
return admin.WrapErrorISE(err, "error generating provisioner config")
}
if err := certProv.Init(*provisionerConfig); err != nil {
if err := certProv.Init(provisionerConfig); err != nil {
return admin.WrapErrorISE(err, "error initializing provisioner %s", nu.Name)
}
@ -431,7 +437,8 @@ func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) {
}
pc := &provisioner.Claims{
DisableRenewal: &c.DisableRenewal,
DisableRenewal: &c.DisableRenewal,
AllowRenewAfterExpiry: &c.AllowRenewAfterExpiry,
}
var err error
@ -469,12 +476,18 @@ func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims {
}
disableRenewal := config.DefaultDisableRenewal
allowRenewAfterExpiry := config.DefaultAllowRenewAfterExpiry
if c.DisableRenewal != nil {
disableRenewal = *c.DisableRenewal
}
if c.AllowRenewAfterExpiry != nil {
allowRenewAfterExpiry = *c.AllowRenewAfterExpiry
}
lc := &linkedca.Claims{
DisableRenewal: disableRenewal,
DisableRenewal: disableRenewal,
AllowRenewAfterExpiry: allowRenewAfterExpiry,
}
if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil {
@ -706,6 +719,8 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface,
Name: p.Name,
TenantID: cfg.TenantId,
ResourceGroups: cfg.ResourceGroups,
SubscriptionIDs: cfg.SubscriptionIds,
ObjectIDs: cfg.ObjectIds,
Audience: cfg.Audience,
DisableCustomSANs: cfg.DisableCustomSans,
DisableTrustOnFirstUse: cfg.DisableTrustOnFirstUse,
@ -865,6 +880,8 @@ func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, erro
Azure: &linkedca.AzureProvisioner{
TenantId: p.TenantID,
ResourceGroups: p.ResourceGroups,
SubscriptionIds: p.SubscriptionIDs,
ObjectIds: p.ObjectIDs,
Audience: p.Audience,
DisableCustomSans: p.DisableCustomSANs,
DisableTrustOnFirstUse: p.DisableTrustOnFirstUse,

@ -1,13 +1,13 @@
package authority
import (
"errors"
"net/http"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
)
func TestGetEncryptedKey(t *testing.T) {
@ -49,8 +49,8 @@ func TestGetEncryptedKey(t *testing.T) {
ek, err := tc.a.GetEncryptedKey(tc.kid)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -90,8 +90,8 @@ func TestGetProvisioners(t *testing.T) {
ps, next, err := tc.a.GetProvisioners("", 0)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}

@ -2,14 +2,15 @@ package authority
import (
"crypto/x509"
"errors"
"net/http"
"reflect"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/pemutil"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
)
func TestRoot(t *testing.T) {
@ -31,7 +32,7 @@ func TestRoot(t *testing.T) {
crt, err := a.Root(tc.sum)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())

@ -7,21 +7,22 @@ import (
"crypto/rand"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"net/http"
"reflect"
"testing"
"time"
"github.com/pkg/errors"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/templates"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"golang.org/x/crypto/ssh"
)
type sshTestModifier ssh.Certificate
@ -716,8 +717,8 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr)
return
} else if err != nil {
_, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
_, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want)
@ -806,8 +807,8 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
hosts, err := auth.GetSSHHosts(context.Background(), tc.cert)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
@ -1033,8 +1034,8 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...)
if err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
}

@ -1,11 +0,0 @@
package status
// Type is the type for status.
type Type string
var (
// Active active
Active = Type("active")
// Deleted deleted
Deleted = Type("deleted")
)

@ -89,8 +89,13 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
// Set backdate with the configured value
signOpts.Backdate = a.config.AuthorityConfig.Backdate.Duration
var prov provisioner.Interface
for _, op := range extraOpts {
switch k := op.(type) {
// Capture current provisioner
case provisioner.Interface:
prov = k
// Adds new options to NewCertificate
case provisioner.CertificateOptions:
certOptions = append(certOptions, k.Options(signOpts)...)
@ -204,7 +209,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
}
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
if err = a.storeCertificate(fullchain); err != nil {
if err = a.storeCertificate(prov, fullchain); err != nil {
if err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.Sign; error storing certificate in db", opts...)
@ -325,19 +330,28 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
// TODO: at some point we should replace the db.AuthDB interface to implement
// `StoreCertificate(...*x509.Certificate) error` instead of just
// `StoreCertificate(*x509.Certificate) error`.
func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error {
func (a *Authority) storeCertificate(prov provisioner.Interface, fullchain []*x509.Certificate) error {
type linkedChainStorer interface {
StoreCertificateChain(provisioner.Interface, ...*x509.Certificate) error
}
type certificateChainStorer interface {
StoreCertificateChain(...*x509.Certificate) error
}
// Store certificate in linkedca
if s, ok := a.adminDB.(certificateChainStorer); ok {
switch s := a.adminDB.(type) {
case linkedChainStorer:
return s.StoreCertificateChain(prov, fullchain...)
case certificateChainStorer:
return s.StoreCertificateChain(fullchain...)
}
// Store certificate in local db
if s, ok := a.db.(certificateChainStorer); ok {
switch s := a.db.(type) {
case certificateChainStorer:
return s.StoreCertificateChain(fullchain...)
default:
return a.db.StoreCertificate(fullchain[0])
}
return a.db.StoreCertificate(fullchain[0])
}
// storeRenewedCertificate allows to use an extension of the db.AuthDB interface

@ -11,24 +11,26 @@ import (
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"errors"
"fmt"
"net/http"
"reflect"
"testing"
"time"
"github.com/smallstep/certificates/cas/softcas"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/cas/softcas"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
)
var (
@ -187,14 +189,14 @@ func setExtraExtsCSR(exts []pkix.Extension) func(*x509.CertificateRequest) {
func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) {
b, err := x509.MarshalPKIXPublicKey(pub)
if err != nil {
return nil, errors.Wrap(err, "error marshaling public key")
return nil, fmt.Errorf("error marshaling public key: %w", err)
}
info := struct {
Algorithm pkix.AlgorithmIdentifier
SubjectPublicKey asn1.BitString
}{}
if _, err = asn1.Unmarshal(b, &info); err != nil {
return nil, errors.Wrap(err, "error unmarshaling public key")
return nil, fmt.Errorf("error unmarshaling public key: %w", err)
}
hash := sha1.Sum(info.SubjectPublicKey.Bytes)
return hash[:], nil
@ -661,8 +663,8 @@ ZYtQ9Ot36qc=
if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain)
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -757,7 +759,7 @@ func TestAuthority_Renew(t *testing.T) {
now := time.Now().UTC()
nb1 := now.Add(-time.Minute * 7)
na1 := now
na1 := now.Add(time.Hour)
so := &provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(nb1),
NotAfter: provisioner.NewTimeDuration(na1),
@ -798,7 +800,20 @@ func TestAuthority_Renew(t *testing.T) {
"fail/unauthorized": func() (*renewTest, error) {
return &renewTest{
cert: certNoRenew,
err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"),
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
code: http.StatusUnauthorized,
}, nil
},
"fail/WithAuthorizeRenewFunc": func() (*renewTest, error) {
aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error {
return errs.Unauthorized("not authorized")
}))
aa.x509CAService = a.x509CAService
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
return &renewTest{
auth: aa,
cert: cert,
err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"),
code: http.StatusUnauthorized,
}, nil
},
@ -820,6 +835,17 @@ func TestAuthority_Renew(t *testing.T) {
cert: cert,
}, nil
},
"ok/WithAuthorizeRenewFunc": func() (*renewTest, error) {
aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error {
return nil
}))
aa.x509CAService = a.x509CAService
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
return &renewTest{
auth: aa,
cert: cert,
}, nil
},
}
for name, genTestCase := range tests {
@ -836,8 +862,8 @@ func TestAuthority_Renew(t *testing.T) {
if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain)
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -856,7 +882,7 @@ func TestAuthority_Renew(t *testing.T) {
expiry := now.Add(time.Minute * 7)
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, leaf.Subject.String(),
@ -956,7 +982,7 @@ func TestAuthority_Rekey(t *testing.T) {
now := time.Now().UTC()
nb1 := now.Add(-time.Minute * 7)
na1 := now
na1 := now.Add(time.Hour)
so := &provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(nb1),
NotAfter: provisioner.NewTimeDuration(na1),
@ -998,7 +1024,7 @@ func TestAuthority_Rekey(t *testing.T) {
"fail/unauthorized": func() (*renewTest, error) {
return &renewTest{
cert: certNoRenew,
err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"),
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
code: http.StatusUnauthorized,
}, nil
},
@ -1043,8 +1069,8 @@ func TestAuthority_Rekey(t *testing.T) {
if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain)
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -1063,7 +1089,7 @@ func TestAuthority_Rekey(t *testing.T) {
expiry := now.Add(time.Minute * 7)
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, leaf.Subject.String(),
@ -1432,8 +1458,8 @@ func TestAuthority_Revoke(t *testing.T) {
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
if err := tc.auth.Revoke(ctx, tc.opts); err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())

@ -12,12 +12,13 @@ import (
"time"
"github.com/pkg/errors"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
acmeAPI "github.com/smallstep/certificates/acme/api"
"github.com/smallstep/certificates/api"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"github.com/smallstep/certificates/api/render"
)
func TestNewACMEClient(t *testing.T) {
@ -112,15 +113,15 @@ func TestNewACMEClient(t *testing.T) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
switch {
case i == 0:
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
case i == 1:
w.Header().Set("Replay-Nonce", "abc123")
api.JSONStatus(w, []byte{}, 200)
render.JSONStatus(w, []byte{}, 200)
i++
default:
w.Header().Set("Location", accLocation)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
}
})
@ -206,7 +207,7 @@ func TestACMEClient_GetNonce(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce)
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
})
if nonce, err := ac.GetNonce(); err != nil {
@ -315,7 +316,7 @@ func TestACMEClient_post(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -338,7 +339,7 @@ func TestACMEClient_post(t *testing.T) {
assert.Equals(t, hdr.KeyID, ac.kid)
}
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil {
@ -455,7 +456,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -477,7 +478,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
assert.FatalError(t, err)
assert.Equals(t, payload, norb)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if res, err := ac.NewOrder(norb); err != nil {
@ -577,7 +578,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -599,7 +600,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
assert.FatalError(t, err)
assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if res, err := ac.GetOrder(url); err != nil {
@ -699,7 +700,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -721,7 +722,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
assert.FatalError(t, err)
assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if res, err := ac.GetAuthz(url); err != nil {
@ -821,7 +822,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -844,7 +845,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if res, err := ac.GetChallenge(url); err != nil {
@ -944,7 +945,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -967,7 +968,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
assert.Equals(t, payload, []byte("{}"))
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if err := ac.ValidateChallenge(url); err != nil {
@ -1071,7 +1072,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -1093,7 +1094,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
assert.FatalError(t, err)
assert.Equals(t, payload, frb)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if err := ac.FinalizeOrder(url, csr); err != nil {
@ -1200,7 +1201,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -1222,7 +1223,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
assert.FatalError(t, err)
assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if res, err := tc.client.GetAccountOrders(); err != nil {
@ -1331,7 +1332,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -1356,7 +1357,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
if tc.certBytes != nil {
w.Write(tc.certBytes)
} else {
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
}
})

@ -14,11 +14,14 @@ import (
"time"
"github.com/pkg/errors"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/randutil"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/randutil"
)
func newLocalListener() net.Listener {
@ -79,7 +82,7 @@ func startCAServer(configFile string) (*CA, string, error) {
func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/version" {
api.JSON(w, api.VersionResponse{
render.JSON(w, api.VersionResponse{
Version: "test",
RequireClientAuthentication: true,
})
@ -93,7 +96,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han
}
isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0
if !isMTLS {
api.WriteError(w, errs.Unauthorized("missing peer certificate"))
render.Error(w, errs.Unauthorized("missing peer certificate"))
} else {
next.ServeHTTP(w, r)
}
@ -408,6 +411,7 @@ func TestBootstrapClientServerRotation(t *testing.T) {
server.ServeTLS(listener, "", "")
}()
defer server.Close()
time.Sleep(1 * time.Second)
// Create bootstrap client
token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
@ -419,7 +423,6 @@ func TestBootstrapClientServerRotation(t *testing.T) {
// doTest does a request that requires mTLS
doTest := func(client *http.Client) error {
time.Sleep(1 * time.Second)
// test with ca
resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
if err != nil {

@ -8,6 +8,7 @@ import (
"net/http"
"net/url"
"reflect"
"strings"
"sync"
"github.com/go-chi/chi"
@ -26,11 +27,14 @@ import (
scepAPI "github.com/smallstep/certificates/scep/api"
"github.com/smallstep/certificates/server"
"github.com/smallstep/nosql"
"go.step.sm/cli-utils/step"
"go.step.sm/crypto/x509util"
)
type options struct {
configFile string
linkedCAToken string
quiet bool
password []byte
issuerPassword []byte
sshHostPassword []byte
@ -101,6 +105,13 @@ func WithLinkedCAToken(token string) Option {
}
}
// WithQuiet sets the quiet flag.
func WithQuiet(quiet bool) Option {
return func(o *options) {
o.quiet = quiet
}
}
// CA is the type used to build the complete certificate authority. It builds
// the HTTP server, set ups the middlewares and the HTTP handlers.
type CA struct {
@ -288,6 +299,35 @@ func (ca *CA) Run() error {
var wg sync.WaitGroup
errs := make(chan error, 1)
if !ca.opts.quiet {
authorityInfo := ca.auth.GetInfo()
log.Printf("Starting %s", step.Version())
log.Printf("Documentation: https://u.step.sm/docs/ca")
log.Printf("Community Discord: https://u.step.sm/discord")
if step.Contexts().GetCurrent() != nil {
log.Printf("Current context: %s", step.Contexts().GetCurrent().Name)
}
log.Printf("Config file: %s", ca.opts.configFile)
baseURL := fmt.Sprintf("https://%s%s",
authorityInfo.DNSNames[0],
ca.config.Address[strings.LastIndex(ca.config.Address, ":"):])
log.Printf("The primary server URL is %s", baseURL)
log.Printf("Root certificates are available at %s/roots.pem", baseURL)
if len(authorityInfo.DNSNames) > 1 {
log.Printf("Additional configured hostnames: %s",
strings.Join(authorityInfo.DNSNames[1:], ", "))
}
for _, crt := range authorityInfo.RootX509Certs {
log.Printf("X.509 Root Fingerprint: %s", x509util.Fingerprint(crt))
}
if authorityInfo.SSHCAHostPublicKey != nil {
log.Printf("SSH Host CA Key is %s\n", authorityInfo.SSHCAHostPublicKey)
}
if authorityInfo.SSHCAUserPublicKey != nil {
log.Printf("SSH User CA Key: %s\n", authorityInfo.SSHCAUserPublicKey)
}
}
if ca.insecureSrv != nil {
wg.Add(1)
go func() {
@ -355,6 +395,7 @@ func (ca *CA) Reload() error {
WithSSHUserPassword(ca.opts.sshUserPassword),
WithIssuerPassword(ca.opts.issuerPassword),
WithLinkedCAToken(ca.opts.linkedCAToken),
WithQuiet(ca.opts.quiet),
WithConfigFile(ca.opts.configFile),
WithDatabase(ca.auth.GetDatabase()),
)
@ -450,9 +491,6 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) {
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
tlsConfig.ClientCAs = certPool
// Use server's most preferred ciphersuite
tlsConfig.PreferServerCipherSuites = true
return tlsConfig, nil
}

@ -563,6 +563,11 @@ func (c *Client) retryOnError(r *http.Response) bool {
return false
}
// GetCaURL returns the configured CA url.
func (c *Client) GetCaURL() string {
return c.endpoint.String()
}
// GetRootCAs returns the RootCAs certificate pool from the configured
// transport.
func (c *Client) GetRootCAs() *x509.CertPool {
@ -723,6 +728,36 @@ retry:
return &sign, nil
}
// RenewWithToken performs the renew request to the CA with the given
// authorization token and returns the api.SignResponse struct. This method is
// generally used to renew an expired certificate.
func (c *Client) RenewWithToken(token string) (*api.SignResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
req, err := http.NewRequest("POST", u.String(), http.NoBody)
if err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error creating request")
}
req.Header.Add("Authorization", "Bearer "+token)
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; client POST %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body)
}
var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error reading %s", u)
}
return &sign, nil
}
// Rekey performs the rekey request to the CA and returns the api.SignResponse
// struct.
func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) {

@ -8,6 +8,7 @@ import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"net/http"
"net/http/httptest"
@ -16,14 +17,16 @@ import (
"testing"
"time"
"github.com/pkg/errors"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
)
const (
@ -179,7 +182,7 @@ func TestClient_Version(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Version()
@ -229,7 +232,7 @@ func TestClient_Health(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Health()
@ -287,7 +290,7 @@ func TestClient_Root(t *testing.T) {
if req.RequestURI != expected {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
}
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Root(tt.shasum)
@ -354,10 +357,10 @@ func TestClient_Sign(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.SignRequest)
if err := api.ReadJSON(req.Body, body); err != nil {
if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e)
render.Error(w, e)
return
} else if !equalJSON(t, body, tt.request) {
if tt.request == nil {
@ -368,7 +371,7 @@ func TestClient_Sign(t *testing.T) {
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
}
}
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Sign(tt.request)
@ -426,10 +429,10 @@ func TestClient_Revoke(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.RevokeRequest)
if err := api.ReadJSON(req.Body, body); err != nil {
if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e)
render.Error(w, e)
return
} else if !equalJSON(t, body, tt.request) {
if tt.request == nil {
@ -440,7 +443,7 @@ func TestClient_Revoke(t *testing.T) {
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
}
}
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Revoke(tt.request, nil)
@ -500,7 +503,7 @@ func TestClient_Renew(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Renew(nil)
@ -516,8 +519,8 @@ func TestClient_Renew(t *testing.T) {
t.Errorf("Client.Renew() = %v, want nil", got)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
@ -529,6 +532,74 @@ func TestClient_Renew(t *testing.T) {
}
}
func TestClient_RenewWithToken(t *testing.T) {
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)},
{Certificate: parseCertificate(rootPEM)},
},
}
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
}
srv := httptest.NewServer(nil)
defer srv.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Header.Get("Authorization") != "Bearer token" {
render.JSONStatus(w, errs.InternalServer("force"), 500)
} else {
render.JSONStatus(w, tt.response, tt.responseCode)
}
})
got, err := c.RenewWithToken("token")
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
switch {
case err != nil:
if got != nil {
t.Errorf("Client.RenewWithToken() = %v, want nil", got)
}
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response)
}
}
})
}
}
func TestClient_Rekey(t *testing.T) {
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
@ -569,7 +640,7 @@ func TestClient_Rekey(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Rekey(tt.request, nil)
@ -585,8 +656,8 @@ func TestClient_Rekey(t *testing.T) {
t.Errorf("Client.Renew() = %v, want nil", got)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
@ -634,7 +705,7 @@ func TestClient_Provisioners(t *testing.T) {
if req.RequestURI != tt.expectedURI {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI)
}
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Provisioners(tt.args...)
@ -691,7 +762,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
if req.RequestURI != expected {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
}
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.ProvisionerKey(tt.kid)
@ -706,8 +777,8 @@ func TestClient_ProvisionerKey(t *testing.T) {
t.Errorf("Client.ProvisionerKey() = %v, want nil", got)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
@ -750,7 +821,7 @@ func TestClient_Roots(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Roots()
@ -765,8 +836,8 @@ func TestClient_Roots(t *testing.T) {
if got != nil {
t.Errorf("Client.Roots() = %v, want nil", got)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
@ -808,7 +879,7 @@ func TestClient_Federation(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Federation()
@ -823,8 +894,8 @@ func TestClient_Federation(t *testing.T) {
if got != nil {
t.Errorf("Client.Federation() = %v, want nil", got)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
@ -870,7 +941,7 @@ func TestClient_SSHRoots(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.SSHRoots()
@ -885,8 +956,8 @@ func TestClient_SSHRoots(t *testing.T) {
if got != nil {
t.Errorf("Client.SSHKeys() = %v, want nil", got)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
@ -970,7 +1041,7 @@ func TestClient_RootFingerprint(t *testing.T) {
}
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.RootFingerprint()
@ -1031,7 +1102,7 @@ func TestClient_SSHBastion(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.SSHBastion(tt.request)
@ -1047,8 +1118,8 @@ func TestClient_SSHBastion(t *testing.T) {
t.Errorf("Client.SSHBastion() = %v, want nil", got)
}
if tt.responseCode != 200 {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
@ -1060,3 +1131,28 @@ func TestClient_SSHBastion(t *testing.T) {
})
}
}
func TestClient_GetCaURL(t *testing.T) {
tests := []struct {
name string
caURL string
want string
}{
{"ok", "https://ca.com", "https://ca.com"},
{"ok no schema", "ca.com", "https://ca.com"},
{"ok with port", "https://ca.com:9000", "https://ca.com:9000"},
{"ok with version", "https://ca.com/1.0", "https://ca.com/1.0"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
if got := c.GetCaURL(); got != tt.want {
t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want)
}
})
}
}

@ -8,6 +8,7 @@ import (
"net/url"
"os"
"reflect"
"sort"
"testing"
)
@ -196,7 +197,7 @@ func TestLoadClient(t *testing.T) {
switch {
case gotTransport.TLSClientConfig.GetClientCertificate == nil:
t.Error("LoadClient() transport does not define GetClientCertificate")
case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()):
case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !equalPools(gotTransport.TLSClientConfig.RootCAs, wantTransport.TLSClientConfig.RootCAs):
t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
default:
crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
@ -238,3 +239,23 @@ func Test_defaultsConfig_Validate(t *testing.T) {
})
}
}
// nolint:staticcheck,gocritic
func equalPools(a, b *x509.CertPool) bool {
if reflect.DeepEqual(a, b) {
return true
}
subjects := a.Subjects()
sA := make([]string, len(subjects))
for i := range subjects {
sA[i] = string(subjects[i])
}
subjects = b.Subjects()
sB := make([]string, len(subjects))
for i := range subjects {
sB[i] = string(subjects[i])
}
sort.Strings(sA)
sort.Strings(sB)
return reflect.DeepEqual(sA, sB)
}

@ -346,6 +346,8 @@ func TestIdentity_GetCertPool(t *testing.T) {
return
}
if got != nil {
// nolint:staticcheck // we don't have a different way to check
// the certificates in the pool.
subjects := got.Subjects()
if !reflect.DeepEqual(subjects, tt.wantSubjects) {
t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)

@ -60,7 +60,10 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption
}
}
period := cert.Leaf.NotAfter.Sub(cert.Leaf.NotBefore)
// Use the current time to calculate the initial period. Using a notBefore
// in the past might set a renewBefore too large, causing continuous
// renewals due to the negative values in nextRenewDuration.
period := cert.Leaf.NotAfter.Sub(time.Now().Truncate(time.Second))
if period < minCertDuration {
return nil, errors.Errorf("period must be greater than or equal to %s, but got %v.", minCertDuration, period)
}
@ -181,7 +184,7 @@ func (r *TLSRenewer) renewCertificate() {
}
func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration {
d := time.Until(notAfter) - r.renewBefore
d := time.Until(notAfter).Truncate(time.Second) - r.renewBefore
n := rand.Int63n(int64(r.renewJitter))
d -= time.Duration(n)
if d < 0 {

@ -6,7 +6,7 @@
"password": "password",
"address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"],
"logger": {"format": "text"},
"_logger": {"format": "text"},
"tls": {
"minVersion": 1.2,
"maxVersion": 1.3,

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save