Merge branch 'master' into crl-support

# Conflicts:
#	api/api.go
#	authority/config/config.go
#	cas/softcas/softcas.go
#	db/db.go
pull/731/head
Raal Goff 2 years ago
commit 60671b07d7

@ -0,0 +1,56 @@
name: Bug Report
description: File a bug report
title: "[Bug]: "
labels: ["bug", "needs triage"]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to fill out this bug report!
- type: textarea
id: steps
attributes:
label: Steps to Reproduce
description: Tell us how to reproduce this issue.
placeholder: These are the steps!
validations:
required: true
- type: textarea
id: your-env
attributes:
label: Your Environment
value: |-
* OS -
* `step-ca` Version -
validations:
required: true
- type: textarea
id: expected-behavior
attributes:
label: Expected Behavior
description: What did you expect to happen?
validations:
required: true
- type: textarea
id: actual-behavior
attributes:
label: Actual Behavior
description: What happens instead?
validations:
required: true
- type: textarea
id: context
attributes:
label: Additional Context
description: Add any other context about the problem here.
validations:
required: false
- type: textarea
id: contributing
attributes:
label: Contributing
value: |
Vote on this issue by adding a 👍 reaction.
To contribute a fix for this issue, leave a comment (and link to your pull request, if you've opened one already).
validations:
required: false

@ -1,27 +0,0 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: bug, needs triage
assignees: ''
---
### Subject of the issue
Describe your issue here.
### Your environment
* OS -
* Version -
### Steps to reproduce
Tell us how to reproduce this issue. Please provide a working demo, you can use [this template](https://plnkr.co/edit/XorWgI?p=preview) as a base.
### Expected behaviour
Tell us what should happen
### Actual behaviour
Tell us what happens instead
### Additional context
Add any other context about the problem here.

@ -1,12 +1,20 @@
---
name: Documentation Request
about: Request documentation for a feature
title: ''
labels: documentation, needs triage
title: '[Docs]:'
labels: docs, needs triage
assignees: ''
---
## Hello!
<!-- Please leave this section as-is, it's designed to help others in the community know how to interact with our GitHub issues. -->
- Vote on this issue by adding a 👍 reaction
- If you want to document this feature, comment to let us know (we'll work with you on design, scheduling, etc.)
## Affected area/feature
<!---
Tell us which feature you'd like to see documented.
- Where would you like that documentation to live (command line usage output, website, github markdown on the repo)?

@ -1,13 +1,24 @@
---
name: Enhancement
about: Suggest an enhancement to step certificates
about: Suggest an enhancement to step-ca
title: ''
labels: enhancement, needs triage
assignees: ''
---
### What would you like to be added
## Hello!
<!-- Please leave this section as-is,
it's designed to help others in the community know how to interact with our GitHub issues. -->
- Vote on this issue by adding a 👍 reaction
- If you want to implement this feature, comment to let us know (we'll work with you on design, scheduling, etc.)
### Why this is needed
## Issue details
<!-- Enhancement requests are most helpful when they describe the problem you're having
as well as articulating the potential solution you'd like to see built. -->
## Why is this needed?
<!-- Let us know why you think this enhancement would be good for the project or community. -->

@ -1,4 +1,20 @@
### Description
Please describe your pull request.
<!---
Please provide answers in the spaces below each prompt, where applicable.
Not every PR requires responses for each prompt.
Use your discretion.
-->
#### Name of feature:
#### Pain or issue this feature alleviates:
#### Why is this important to the project (if not answered above):
#### Is there documentation on how to use this feature? If so, where?
#### In what environments or workflows is this feature supported?
#### In what environments or workflows is this feature explicitly NOT supported (if any)?
#### Supporting links/other PRs/issues:
💔Thank you!

@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2
with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.45.0'
version: 'v1.45.2'
# Optional: working directory, useful for monorepos
# working-directory: somedir
@ -139,7 +139,7 @@ jobs:
name: Run GoReleaser
uses: goreleaser/goreleaser-action@5a54d7e660bda43b405e8463261b3d25631ffe86 # v2.7.0
with:
version: latest
version: 'v1.7.0'
args: release --rm-dist
env:
GITHUB_TOKEN: ${{ secrets.PAT }}

@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2
with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.45.0'
version: 'v1.45.2'
# Optional: working directory, useful for monorepos
# working-directory: somedir
@ -59,8 +59,9 @@ jobs:
-
name: Codecov
if: matrix.go == '1.18'
uses: codecov/codecov-action@v1.2.1
uses: codecov/codecov-action@v2
with:
file: ./coverage.out # optional
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.out # optional
name: codecov-umbrella # optional
fail_ci_if_error: true # optional (default = false)

@ -230,42 +230,3 @@ scoop:
# Your app's license
# Default is empty.
license: "Apache-2.0"
#dockers:
# - dockerfile: docker/Dockerfile
# goos: linux
# goarch: amd64
# use_buildx: true
# image_templates:
# - "smallstep/step-cli:latest"
# - "smallstep/step-cli:{{ .Tag }}"
# build_flag_templates:
# - "--platform=linux/amd64"
# - dockerfile: docker/Dockerfile
# goos: linux
# goarch: 386
# use_buildx: true
# image_templates:
# - "smallstep/step-cli:latest"
# - "smallstep/step-cli:{{ .Tag }}"
# build_flag_templates:
# - "--platform=linux/386"
# - dockerfile: docker/Dockerfile
# goos: linux
# goarch: arm
# goarm: 7
# use_buildx: true
# image_templates:
# - "smallstep/step-cli:latest"
# - "smallstep/step-cli:{{ .Tag }}"
# build_flag_templates:
# - "--platform=linux/arm/v7"
# - dockerfile: docker/Dockerfile
# goos: linux
# goarch: arm64
# use_buildx: true
# image_templates:
# - "smallstep/step-cli:latest"
# - "smallstep/step-cli:{{ .Tag }}"
# build_flag_templates:
# - "--platform=linux/arm64/v8"

@ -4,16 +4,70 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased - 0.18.3] - DATE
### TEMPLATE -- do not alter or remove
---
## [x.y.z] - aaaa-bb-cc
### Added
- Added support for renew after expiry using the claim `allowRenewAfterExpiry`.
### Changed
### Deprecated
### Removed
### Fixed
### Security
---
## [Unreleased]
### Changed
- Certificates signed by an issuer using an RSA key will be signed using the same algorithm as the issuer certificate was signed with. The signature will no longer default to PKCS #1. For example, if the issuer certificate was signed using RSA-PSS with SHA-256, a new certificate will also be signed using RSA-PSS with SHA-256.
## [0.20.0] - 2022-05-26
### Added
- Added Kubernetes auth method for Vault RAs.
- Added support for reporting provisioners to linkedca.
- Added support for certificate policies on authority level.
- Added a Dockerfile with a step-ca build with HSM support.
- A few new WithXX methods for instantiating authorities
### Changed
- Context usage in HTTP APIs.
- Changed authentication for Vault RAs.
- Error message returned to client when authenticating with expired certificate.
- Strip padding from ACME CSRs.
### Deprecated
- HTTP API handler types.
### Fixed
- Fixed SSH revocation.
- CA client dial context for js/wasm target.
- Incomplete `extraNames` support in templates.
- SCEP GET request support.
- Large SCEP request handling.
## [0.19.0] - 2022-04-19
### Added
- Added support for certificate renewals after expiry using the claim `allowRenewalAfterExpiry`.
- Added support for `extraNames` in X.509 templates.
- Added `armv5` builds.
- Added RA support using a Vault instance as the CA.
- Added `WithX509SignerFunc` authority option.
- Added a new `/roots.pem` endpoint to download the CA roots in PEM format.
- Added support for Azure `Managed Identity` tokens.
- Added support for automatic configuration of linked RAs.
- Added support for the `--context` flag. It's now possible to start the
CA with `step-ca --context=abc` to use the configuration from context `abc`.
When a context has been configured and no configuration file is provided
on startup, the configuration for the current context is used.
- Added startup info logging and option to skip it (`--quiet`).
- Added support for renaming the CA (Common Name).
### Changed
- Made SCEP CA URL paths dynamic
- Support two latest versions of Go (1.17, 1.18)
- Made SCEP CA URL paths dynamic.
- Support two latest versions of Go (1.17, 1.18).
- Upgrade go.step.sm/crypto to v0.16.1.
- Upgrade go.step.sm/linkedca to v0.15.0.
### Deprecated
- Go 1.16 support.
### Removed
### Fixed
- Fixed admin credentials on RAs.
- Fixed ACME HTTP-01 challenges for IPv6 identifiers.
- Various improvements under the hood.
### Security
## [0.18.2] - 2022-03-01
@ -49,7 +103,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Support for multiple certificate authority contexts.
- Support for generating extractable keys and certificates on a pkcs#11 module.
### Changed
- Support two latest versions of golang (1.16, 1.17)
- Support two latest versions of Go (1.16, 1.17)
### Deprecated
- go 1.15 support

@ -151,7 +151,7 @@ integration: bin/$(BINNAME)
#########################################
fmt:
$Q gofmt -l -w $(SRC)
$Q gofmt -l -s -w $(SRC)
lint:
$Q golangci-lint run --timeout=30m

@ -54,7 +54,7 @@ Setting up a *public key infrastructure* (PKI) is out of reach for many small te
- [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with automated enrollment, renewal, and passive revocation
- Capable of high availability (HA) deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries
- Can operate as [an online intermediate CA for an existing root CA](https://smallstep.com/docs/tutorials/intermediate-ca-new-ca)
- [Badger, BoltDB, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases)
- [Badger, BoltDB, Postgres, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases)
### ⚙️ Many ways to automate
@ -68,6 +68,7 @@ You can issue certificates in exchange for:
- [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure
- [Single-use, short-lived JWK tokens](https://smallstep.com/docs/step-ca/provisioners#jwk) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc.
- A trusted X.509 certificate (X5C provisioner)
- A host certificate from your Nebula network
- A SCEP challenge (SCEP provisioner)
- An SSH host certificates needing renewal (the SSHPOP provisioner)
- Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/provisioners)

@ -7,6 +7,8 @@ import (
"time"
"go.step.sm/crypto/jose"
"github.com/smallstep/certificates/authority/policy"
)
// Account is a subset of the internal account type containing only those
@ -43,15 +45,63 @@ func KeyToID(jwk *jose.JSONWebKey) (string, error) {
return base64.RawURLEncoding.EncodeToString(kid), nil
}
// PolicyNames contains ACME account level policy names
type PolicyNames struct {
DNSNames []string `json:"dns"`
IPRanges []string `json:"ips"`
}
// X509Policy contains ACME account level X.509 policy
type X509Policy struct {
Allowed PolicyNames `json:"allow"`
Denied PolicyNames `json:"deny"`
AllowWildcardNames bool `json:"allowWildcardNames"`
}
// Policy is an ACME Account level policy
type Policy struct {
X509 X509Policy `json:"x509"`
}
func (p *Policy) GetAllowedNameOptions() *policy.X509NameOptions {
if p == nil {
return nil
}
return &policy.X509NameOptions{
DNSDomains: p.X509.Allowed.DNSNames,
IPRanges: p.X509.Allowed.IPRanges,
}
}
func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions {
if p == nil {
return nil
}
return &policy.X509NameOptions{
DNSDomains: p.X509.Denied.DNSNames,
IPRanges: p.X509.Denied.IPRanges,
}
}
// AreWildcardNamesAllowed returns if wildcard names
// like *.example.com are allowed to be signed.
// Defaults to false.
func (p *Policy) AreWildcardNamesAllowed() bool {
if p == nil {
return false
}
return p.X509.AllowWildcardNames
}
// ExternalAccountKey is an ACME External Account Binding key.
type ExternalAccountKey struct {
ID string `json:"id"`
ProvisionerID string `json:"provisionerID"`
Reference string `json:"reference"`
AccountID string `json:"-"`
KeyBytes []byte `json:"-"`
HmacKey []byte `json:"-"`
CreatedAt time.Time `json:"createdAt"`
BoundAt time.Time `json:"boundAt,omitempty"`
Policy *Policy `json:"policy,omitempty"`
}
// AlreadyBound returns whether this EAK is already bound to
@ -68,6 +118,6 @@ func (eak *ExternalAccountKey) BindTo(account *Account) error {
}
eak.AccountID = account.ID
eak.BoundAt = time.Now()
eak.KeyBytes = []byte{} // clearing the key bytes; can only be used once
eak.HmacKey = []byte{} // clearing the key bytes; can only be used once
return nil
}

@ -7,8 +7,9 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
)
func TestKeyToID(t *testing.T) {
@ -95,7 +96,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
ID: "eakID",
ProvisionerID: "provID",
Reference: "ref",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
},
acct: &Account{
ID: "accountID",
@ -108,7 +109,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
ID: "eakID",
ProvisionerID: "provID",
Reference: "ref",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
AccountID: "someAccountID",
BoundAt: boundAt,
},
@ -138,7 +139,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
assert.Equals(t, ae.Subproblems, tt.err.Subproblems)
} else {
assert.Equals(t, eak.AccountID, acct.ID)
assert.Equals(t, eak.KeyBytes, []byte{})
assert.Equals(t, eak.HmacKey, []byte{})
assert.NotNil(t, eak.BoundAt)
}
})

@ -67,8 +67,11 @@ func (u *UpdateAccountRequest) Validate() error {
}
// NewAccount is the handler resource for creating new ACME accounts.
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
func NewAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
payload, err := payloadFromContext(ctx)
if err != nil {
render.Error(w, err)
@ -114,7 +117,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
return
}
eak, err := h.validateExternalAccountBinding(ctx, &nar)
eak, err := validateExternalAccountBinding(ctx, &nar)
if err != nil {
render.Error(w, err)
return
@ -125,18 +128,17 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
Contact: nar.Contact,
Status: acme.StatusValid,
}
if err := h.db.CreateAccount(ctx, acc); err != nil {
if err := db.CreateAccount(ctx, acc); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating account"))
return
}
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
err := eak.BindTo(acc)
if err != nil {
if err := eak.BindTo(acc); err != nil {
render.Error(w, err)
return
}
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
return
}
@ -147,15 +149,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
httpStatus = http.StatusOK
}
h.linker.LinkAccount(ctx, acc)
linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID))
render.JSONStatus(w, acc, httpStatus)
}
// GetOrUpdateAccount is the api for updating an ACME account.
func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
render.Error(w, err)
@ -187,16 +192,16 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
acc.Contact = uar.Contact
}
if err := h.db.UpdateAccount(ctx, acc); err != nil {
if err := db.UpdateAccount(ctx, acc); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating account"))
return
}
}
}
h.linker.LinkAccount(ctx, acc)
linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID))
render.JSON(w, acc)
}
@ -210,8 +215,11 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
}
// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
render.Error(w, err)
@ -222,13 +230,14 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
return
}
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
orders, err := db.GetOrdersByAccountID(ctx, acc.ID)
if err != nil {
render.Error(w, err)
return
}
h.linker.LinkOrdersByAccountID(ctx, orders)
linker.LinkOrdersByAccountID(ctx, orders)
render.JSON(w, orders)
logOrdersByAccount(w, orders)

@ -13,10 +13,12 @@ import (
"github.com/go-chi/chi"
"github.com/pkg/errors"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
)
var (
@ -29,6 +31,22 @@ var (
}
)
type fakeProvisioner struct{}
func (*fakeProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
return nil
}
func (*fakeProvisioner) AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
return nil, nil
}
func (*fakeProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { return nil }
func (*fakeProvisioner) GetID() string { return "" }
func (*fakeProvisioner) GetName() string { return "" }
func (*fakeProvisioner) DefaultTLSCertDuration() time.Duration { return 0 }
func (*fakeProvisioner) GetOptions() *provisioner.Options { return nil }
func newProv() acme.Provisioner {
// Initialize provisioners
p := &provisioner.ACME{
@ -41,6 +59,19 @@ func newProv() acme.Provisioner {
return p
}
func newProvWithOptions(options *provisioner.Options) acme.Provisioner {
// Initialize provisioners
p := &provisioner.ACME{
Type: "ACME",
Name: "test@acme-<test>provisioner.com",
Options: options,
}
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
fmt.Printf("%v", err)
}
return p
}
func newACMEProv(t *testing.T) *provisioner.ACME {
p := newProv()
a, ok := p.(*provisioner.ACME)
@ -50,6 +81,15 @@ func newACMEProv(t *testing.T) *provisioner.ACME {
return a
}
func newACMEProvWithOptions(t *testing.T, options *provisioner.Options) *provisioner.ACME {
p := newProvWithOptions(options)
a, ok := p.(*provisioner.ACME)
if !ok {
t.Fatal("not a valid ACME provisioner")
}
return a
}
func createEABJWS(jwk *jose.JSONWebKey, hmacKey []byte, keyID, u string) (*jose.JSONWebSignature, error) {
signer, err := jose.NewSigner(
jose.SigningKey{
@ -296,10 +336,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: accID}
ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, accContextKey, acc)
return test{
db: &acme.MockDB{
MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) {
@ -315,11 +354,11 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.GetOrdersByAccountID(w, req)
GetOrdersByAccountID(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -363,6 +402,7 @@ func TestHandler_NewAccount(t *testing.T) {
var tests = map[string]func(t *testing.T) test{
"fail/no-payload": func(t *testing.T) test {
return test{
db: &acme.MockDB{},
ctx: context.Background(),
statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"),
@ -371,6 +411,7 @@ func TestHandler_NewAccount(t *testing.T) {
"fail/nil-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), payloadContextKey, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"),
@ -379,6 +420,7 @@ func TestHandler_NewAccount(t *testing.T) {
"fail/unmarshal-payload-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{})
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to "+
@ -393,6 +435,7 @@ func TestHandler_NewAccount(t *testing.T) {
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
@ -405,8 +448,9 @@ func TestHandler_NewAccount(t *testing.T) {
b, err := json.Marshal(nar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -418,9 +462,10 @@ func TestHandler_NewAccount(t *testing.T) {
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("jwk expected in request context"),
@ -432,10 +477,11 @@ func TestHandler_NewAccount(t *testing.T) {
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("jwk expected in request context"),
@ -454,9 +500,9 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"),
@ -471,7 +517,7 @@ func TestHandler_NewAccount(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwkContextKey, jwk)
return test{
db: &acme.MockDB{
@ -501,18 +547,11 @@ func TestHandler_NewAccount(t *testing.T) {
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
scepProvisioner := &provisioner.SCEP{
Type: "SCEP",
Name: "test@scep-<test>provisioner.com",
}
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
assert.FatalError(t, err)
}
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"),
@ -551,14 +590,13 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
eak := &acme.ExternalAccountKey{
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(),
}
return test{
@ -599,8 +637,7 @@ func TestHandler_NewAccount(t *testing.T) {
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
return test{
db: &acme.MockDB{
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
@ -635,11 +672,11 @@ func TestHandler_NewAccount(t *testing.T) {
Status: acme.StatusValid,
Contact: []string{"foo", "bar"},
}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
db: &acme.MockDB{},
ctx: ctx,
acc: acc,
statusCode: 200,
@ -664,8 +701,7 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = false
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
return test{
db: &acme.MockDB{
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
@ -719,8 +755,7 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
db: &acme.MockDB{
@ -735,7 +770,7 @@ func TestHandler_NewAccount(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(),
}, nil
},
@ -759,11 +794,11 @@ func TestHandler_NewAccount(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.NewAccount(w, req)
NewAccount(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -814,6 +849,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
db: &acme.MockDB{},
ctx: context.Background(),
statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -822,6 +858,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), accContextKey, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -830,6 +867,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
"fail/no-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), accContextKey, &acc)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"),
@ -839,6 +877,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
ctx := context.WithValue(context.Background(), accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"),
@ -848,6 +887,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
ctx := context.WithValue(context.Background(), accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"),
@ -862,6 +902,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
ctx := context.WithValue(context.Background(), accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
@ -894,10 +935,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
}
b, err := json.Marshal(uar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
db: &acme.MockDB{
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
@ -914,11 +954,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
uar := &UpdateAccountRequest{}
b, err := json.Marshal(uar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 200,
}
@ -929,10 +969,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
}
b, err := json.Marshal(uar)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
db: &acme.MockDB{
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
@ -946,11 +985,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
}
},
"ok/post-as-get": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 200,
}
@ -959,11 +998,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.GetOrUpdateAccount(w, req)
GetOrUpdateAccount(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)

@ -4,8 +4,9 @@ import (
"context"
"encoding/json"
"github.com/smallstep/certificates/acme"
"go.step.sm/crypto/jose"
"github.com/smallstep/certificates/acme"
)
// ExternalAccountBinding represents the ACME externalAccountBinding JWS
@ -16,7 +17,7 @@ type ExternalAccountBinding struct {
}
// validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account.
func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil {
return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context")
@ -47,7 +48,8 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc
return nil, acmeErr
}
externalAccountKey, err := h.db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
db := acme.MustDatabaseFromContext(ctx)
externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
if err != nil {
if _, ok := err.(*acme.Error); ok {
return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key")
@ -55,11 +57,19 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc
return nil, acme.WrapErrorISE(err, "error retrieving external account key")
}
if externalAccountKey == nil {
return nil, acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key")
}
if len(externalAccountKey.HmacKey) == 0 {
return nil, acme.NewError(acme.ErrorServerInternalType, "external account binding key with id '%s' does not have secret bytes", keyID)
}
if externalAccountKey.AlreadyBound() {
return nil, acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", keyID, externalAccountKey.AccountID, externalAccountKey.BoundAt)
}
payload, err := eabJWS.Verify(externalAccountKey.KeyBytes)
payload, err := eabJWS.Verify(externalAccountKey.HmacKey)
if err != nil {
return nil, acme.WrapErrorISE(err, "error verifying externalAccountBinding signature")
}
@ -97,12 +107,12 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool {
// validateEABJWS verifies the contents of the External Account Binding JWS.
// The protected header of the JWS MUST meet the following criteria:
// o The "alg" field MUST indicate a MAC-based algorithm
// o The "kid" field MUST contain the key identifier provided by the CA
// o The "nonce" field MUST NOT be present
// o The "url" field MUST be set to the same value as the outer JWS
//
// - The "alg" field MUST indicate a MAC-based algorithm
// - The "kid" field MUST contain the key identifier provided by the CA
// - The "nonce" field MUST NOT be present
// - The "url" field MUST be set to the same value as the outer JWS
func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) {
if jws == nil {
return "", acme.NewErrorISE("no JWS provided")
}

@ -9,10 +9,11 @@ import (
"time"
"github.com/pkg/errors"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
)
func Test_keysAreEqual(t *testing.T) {
@ -98,8 +99,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
assert.FatalError(t, err)
prov := newACMEProv(t)
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
return test{
db: &acme.MockDB{},
ctx: ctx,
@ -143,8 +143,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
createdAt := time.Now()
return test{
@ -154,7 +153,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: createdAt,
}, nil
},
@ -168,7 +167,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: createdAt,
},
err: nil,
@ -189,17 +188,10 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
scepProvisioner := &provisioner.SCEP{
Type: "SCEP",
Name: "test@scep-<test>provisioner.com",
}
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
assert.FatalError(t, err)
}
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
return test{
ctx: ctx,
err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"),
@ -218,8 +210,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
return test{
db: &acme.MockDB{},
ctx: ctx,
@ -264,8 +255,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
db: &acme.MockDB{},
@ -310,8 +300,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
db: &acme.MockDB{
@ -358,8 +347,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
db: &acme.MockDB{
@ -408,8 +396,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
db: &acme.MockDB{
@ -426,6 +413,112 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
err: acme.NewErrorISE("error retrieving external account key"),
}
},
"fail/db.GetExternalAccountKey-nil": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
assert.FatalError(t, err)
eab := &ExternalAccountBinding{}
err = json.Unmarshal(rawEABJWS, &eab)
assert.FatalError(t, err)
nar := &NewAccountRequest{
Contact: []string{"foo", "bar"},
ExternalAccountBinding: eab,
}
payloadBytes, err := json.Marshal(nar)
assert.FatalError(t, err)
so := new(jose.SignerOptions)
so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
so.WithHeader("url", url)
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
Key: jwk.Key,
}, so)
assert.FatalError(t, err)
jws, err := signer.Sign(payloadBytes)
assert.FatalError(t, err)
raw, err := jws.CompactSerialize()
assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
db: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) {
return nil, nil
},
},
ctx: ctx,
nar: &NewAccountRequest{
Contact: []string{"foo", "bar"},
ExternalAccountBinding: eab,
},
eak: nil,
err: acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key"),
}
},
"fail/db.GetExternalAccountKey-no-keybytes": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
assert.FatalError(t, err)
eab := &ExternalAccountBinding{}
err = json.Unmarshal(rawEABJWS, &eab)
assert.FatalError(t, err)
nar := &NewAccountRequest{
Contact: []string{"foo", "bar"},
ExternalAccountBinding: eab,
}
payloadBytes, err := json.Marshal(nar)
assert.FatalError(t, err)
so := new(jose.SignerOptions)
so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
so.WithHeader("url", url)
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
Key: jwk.Key,
}, so)
assert.FatalError(t, err)
jws, err := signer.Sign(payloadBytes)
assert.FatalError(t, err)
raw, err := jws.CompactSerialize()
assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
createdAt := time.Now()
return test{
db: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) {
return &acme.ExternalAccountKey{
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
CreatedAt: createdAt,
AccountID: "some-account-id",
HmacKey: []byte{},
}, nil
},
},
ctx: ctx,
nar: &NewAccountRequest{
Contact: []string{"foo", "bar"},
ExternalAccountBinding: eab,
},
eak: nil,
err: acme.NewError(acme.ErrorServerInternalType, "external account binding key with id 'eakID' does not have secret bytes"),
}
},
"fail/db.GetExternalAccountKey-wrong-provisioner": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
@ -458,8 +551,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
db: &acme.MockDB{
@ -506,8 +598,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
createdAt := time.Now()
boundAt := time.Now().Add(1 * time.Second)
@ -520,6 +611,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
Reference: "testeak",
CreatedAt: createdAt,
AccountID: "some-account-id",
HmacKey: []byte{1, 3, 3, 7},
BoundAt: boundAt,
}, nil
},
@ -565,8 +657,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
db: &acme.MockDB{
@ -575,7 +666,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 2, 3, 4},
HmacKey: []byte{1, 2, 3, 4},
CreatedAt: time.Now(),
}, nil
},
@ -623,8 +714,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
db: &acme.MockDB{
@ -633,7 +723,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(),
}, nil
},
@ -678,8 +768,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
assert.FatalError(t, err)
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
db: &acme.MockDB{
@ -688,7 +777,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(),
}, nil
},
@ -734,8 +823,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, nil)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
db: &acme.MockDB{
@ -744,7 +832,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(),
}, nil
},
@ -762,10 +850,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
db: tc.db,
}
got, err := h.validateExternalAccountBinding(tc.ctx, tc.nar)
ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
got, err := validateExternalAccountBinding(ctx, tc.nar)
wantErr := tc.err != nil
gotErr := err != nil
if wantErr != gotErr {
@ -787,7 +873,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
} else {
assert.NotNil(t, tc.eak)
assert.Equals(t, got.ID, tc.eak.ID)
assert.Equals(t, got.KeyBytes, tc.eak.KeyBytes)
assert.Equals(t, got.HmacKey, tc.eak.HmacKey)
assert.Equals(t, got.ProvisionerID, tc.eak.ProvisionerID)
assert.Equals(t, got.Reference, tc.eak.Reference)
assert.Equals(t, got.CreatedAt, tc.eak.CreatedAt)

@ -2,12 +2,10 @@ package api
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"net"
"net/http"
"time"
@ -16,6 +14,7 @@ import (
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
)
@ -39,111 +38,152 @@ type payloadInfo struct {
isEmptyJSON bool
}
// Handler is the ACME API request handler.
type Handler struct {
db acme.DB
backdate provisioner.Duration
ca acme.CertificateAuthority
linker Linker
validateChallengeOptions *acme.ValidateChallengeOptions
prerequisitesChecker func(ctx context.Context) (bool, error)
}
// HandlerOptions required to create a new ACME API request handler.
type HandlerOptions struct {
Backdate provisioner.Duration
// DB storage backend that impements the acme.DB interface.
// DB storage backend that implements the acme.DB interface.
//
// Deprecated: use acme.NewContex(context.Context, acme.DB)
DB acme.DB
// CA is the certificate authority interface.
//
// Deprecated: use authority.NewContext(context.Context, *authority.Authority)
CA acme.CertificateAuthority
// Backdate is the duration that the CA will subtract from the current time
// to set the NotBefore in the certificate.
Backdate provisioner.Duration
// DNS the host used to generate accurate ACME links. By default the authority
// will use the Host from the request, so this value will only be used if
// request.Host is empty.
DNS string
// Prefix is a URL path prefix under which the ACME api is served. This
// prefix is required to generate accurate ACME links.
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
// "acme" is the prefix from which the ACME api is accessed.
Prefix string
CA acme.CertificateAuthority
// PrerequisitesChecker checks if all prerequisites for serving ACME are
// met by the CA configuration.
PrerequisitesChecker func(ctx context.Context) (bool, error)
}
var mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
return authority.MustFromContext(ctx)
}
// handler is the ACME API request handler.
type handler struct {
opts *HandlerOptions
}
// Route traffic and implement the Router interface. For backward compatibility
// this route adds will add a new middleware that will set the ACME components
// on the context.
//
// Note: this method is deprecated in step-ca, other applications can still use
// this to support ACME, but the recommendation is to use use
// api.Route(api.Router) and acme.NewContext() instead.
func (h *handler) Route(r api.Router) {
client := acme.NewClient()
linker := acme.NewLinker(h.opts.DNS, h.opts.Prefix)
route(r, func(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil {
ctx = authority.NewContext(ctx, ca)
}
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker)
next(w, r.WithContext(ctx))
}
})
}
// NewHandler returns a new ACME API handler.
func NewHandler(ops HandlerOptions) api.RouterHandler {
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
client := http.Client{
Timeout: 30 * time.Second,
Transport: transport,
}
dialer := &net.Dialer{
Timeout: 30 * time.Second,
}
prerequisitesChecker := func(ctx context.Context) (bool, error) {
// by default all prerequisites are met
return true, nil
}
if ops.PrerequisitesChecker != nil {
prerequisitesChecker = ops.PrerequisitesChecker
}
return &Handler{
ca: ops.CA,
db: ops.DB,
backdate: ops.Backdate,
linker: NewLinker(ops.DNS, ops.Prefix),
validateChallengeOptions: &acme.ValidateChallengeOptions{
HTTPGet: client.Get,
LookupTxt: net.LookupTXT,
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(dialer, network, addr, config)
},
},
prerequisitesChecker: prerequisitesChecker,
//
// Note: this method is deprecated in step-ca, other applications can still use
// this to support ACME, but the recommendation is to use use
// api.Route(api.Router) and acme.NewContext() instead.
func NewHandler(opts HandlerOptions) api.RouterHandler {
return &handler{
opts: &opts,
}
}
// Route traffic and implement the Router interface.
func (h *Handler) Route(r api.Router) {
getPath := h.linker.GetUnescapedPathSuffix
// Standard ACME API
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
// Route traffic and implement the Router interface. This method requires that
// all the acme components, authority, db, client, linker, and prerequisite
// checker to be present in the context.
func Route(r api.Router) {
route(r, nil)
}
func route(r api.Router, middleware func(next nextHTTP) nextHTTP) {
commonMiddleware := func(next nextHTTP) nextHTTP {
handler := func(w http.ResponseWriter, r *http.Request) {
// Linker middleware gets the provisioner and current url from the
// request and sets them in the context.
linker := acme.MustLinkerFromContext(r.Context())
linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r)
}
if middleware != nil {
handler = middleware(handler)
}
return handler
}
validatingMiddleware := func(next nextHTTP) nextHTTP {
return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))))
return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))
}
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))
}
extractPayloadByKid := func(next nextHTTP) nextHTTP {
return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))
return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))
}
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next)))
return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))
}
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount))
r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount))
r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented))
r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder))
r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID)))
r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization)))
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge))
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert))
getPath := acme.GetUnescapedPathSuffix
// Standard ACME API
r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"),
commonMiddleware(addNonce(addDirLink(GetNonce))))
r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"),
commonMiddleware(addNonce(addDirLink(GetNonce))))
r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"),
commonMiddleware(GetDirectory))
r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"),
commonMiddleware(GetDirectory))
r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"),
extractPayloadByJWK(NewAccount))
r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(GetOrUpdateAccount))
r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(NotImplemented))
r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"),
extractPayloadByKid(NewOrder))
r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"),
extractPayloadByKid(isPostAsGet(GetOrder)))
r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(isPostAsGet(GetOrdersByAccountID)))
r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"),
extractPayloadByKid(FinalizeOrder))
r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"),
extractPayloadByKid(isPostAsGet(GetAuthorization)))
r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"),
extractPayloadByKid(GetChallenge))
r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"),
extractPayloadByKid(isPostAsGet(GetCertificate)))
r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"),
extractPayloadByKidOrJWK(RevokeCert))
}
// GetNonce just sets the right header since a Nonce is added to each response
// by middleware by default.
func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
func GetNonce(w http.ResponseWriter, r *http.Request) {
if r.Method == "HEAD" {
w.WriteHeader(http.StatusOK)
} else {
@ -179,7 +219,7 @@ func (d *Directory) ToLog() (interface{}, error) {
// GetDirectory is the ACME resource for returning a directory configuration
// for client configuration.
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
func GetDirectory(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil {
@ -187,12 +227,13 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
return
}
linker := acme.MustLinkerFromContext(ctx)
render.JSON(w, &Directory{
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType),
KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType),
NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType),
KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType),
Meta: Meta{
ExternalAccountRequired: acmeProv.RequireEAB,
},
@ -201,19 +242,22 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
// NotImplemented returns a 501 and is generally a placeholder for functionality which
// MAY be added at some point in the future but is not in any way a guarantee of such.
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
func NotImplemented(w http.ResponseWriter, r *http.Request) {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
}
// GetAuthorization ACME api for retrieving an Authz.
func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
func GetAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
return
@ -223,20 +267,23 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
return
}
if err = az.UpdateStatus(ctx, h.db); err != nil {
if err = az.UpdateStatus(ctx, db); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
return
}
h.linker.LinkAuthorization(ctx, az)
linker.LinkAuthorization(ctx, az)
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID))
render.JSON(w, az)
}
// GetChallenge ACME api for retrieving a Challenge.
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
func GetChallenge(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
render.Error(w, err)
@ -257,7 +304,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
// we'll just ignore the body.
azID := chi.URLParam(r, "authzID")
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
return
@ -273,29 +320,31 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
render.Error(w, err)
return
}
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
if err = ch.Validate(ctx, db, jwk); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
return
}
h.linker.LinkChallenge(ctx, ch, azID)
linker.LinkChallenge(ctx, ch, azID)
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up"))
w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID))
render.JSON(w, ch)
}
// GetCertificate ACME api for retrieving a Certificate.
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
func GetCertificate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
certID := chi.URLParam(r, "certID")
cert, err := h.db.GetCertificate(ctx, certID)
certID := chi.URLParam(r, "certID")
cert, err := db.GetCertificate(ctx, certID)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
return

@ -3,6 +3,7 @@ package api
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
@ -19,11 +20,33 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
)
type mockClient struct {
get func(url string) (*http.Response, error)
lookupTxt func(name string) ([]string, error)
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
}
func (m *mockClient) Get(u string) (*http.Response, error) { return m.get(u) }
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
return m.tlsDial(network, addr, config)
}
func mockMustAuthority(t *testing.T, a acme.CertificateAuthority) {
t.Helper()
fn := mustAuthority
t.Cleanup(func() {
mustAuthority = fn
})
mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
return a
}
}
func TestHandler_GetNonce(t *testing.T) {
tests := []struct {
name string
@ -38,10 +61,10 @@ func TestHandler_GetNonce(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &Handler{}
// h := &Handler{}
w := httptest.NewRecorder()
req.Method = tt.name
h.GetNonce(w, req)
GetNonce(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
@ -52,7 +75,8 @@ func TestHandler_GetNonce(t *testing.T) {
}
func TestHandler_GetDirectory(t *testing.T) {
linker := NewLinker("ca.smallstep.com", "acme")
linker := acme.NewLinker("ca.smallstep.com", "acme")
_ = linker
type test struct {
ctx context.Context
statusCode int
@ -61,23 +85,14 @@ func TestHandler_GetDirectory(t *testing.T) {
}
var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
ctx: ctx,
ctx: context.Background(),
statusCode: 500,
err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"),
err: acme.NewErrorISE("provisioner is not in context"),
}
},
"fail/different-provisioner": func(t *testing.T) test {
prov := &provisioner.SCEP{
Type: "SCEP",
Name: "test@scep-<test>provisioner.com",
}
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx := acme.NewProvisionerContext(context.Background(), &fakeProvisioner{})
return test{
ctx: ctx,
statusCode: 500,
@ -88,8 +103,7 @@ func TestHandler_GetDirectory(t *testing.T) {
prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx := acme.NewProvisionerContext(context.Background(), prov)
expDir := Directory{
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
@ -108,8 +122,7 @@ func TestHandler_GetDirectory(t *testing.T) {
prov.RequireEAB = true
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx := acme.NewProvisionerContext(context.Background(), prov)
expDir := Directory{
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
@ -130,11 +143,11 @@ func TestHandler_GetDirectory(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{linker: linker}
ctx := acme.NewLinkerContext(tc.ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.GetDirectory(w, req)
GetDirectory(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -219,7 +232,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
}
},
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil)
return test{
db: &acme.MockDB{},
@ -285,10 +298,9 @@ func TestHandler_GetAuthorization(t *testing.T) {
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
db: &acme.MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
@ -304,11 +316,11 @@ func TestHandler_GetAuthorization(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.GetAuthorization(w, req)
GetAuthorization(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -447,11 +459,11 @@ func TestHandler_GetCertificate(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db}
ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.GetCertificate(w, req)
GetCertificate(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -491,7 +503,7 @@ func TestHandler_GetChallenge(t *testing.T) {
type test struct {
db acme.DB
vco *acme.ValidateChallengeOptions
vc acme.Client
ctx context.Context
statusCode int
ch *acme.Challenge
@ -500,6 +512,7 @@ func TestHandler_GetChallenge(t *testing.T) {
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
db: &acme.MockDB{},
ctx: context.Background(),
statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -507,6 +520,7 @@ func TestHandler_GetChallenge(t *testing.T) {
},
"fail/nil-account": func(t *testing.T) test {
return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), accContextKey, nil),
statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -516,6 +530,7 @@ func TestHandler_GetChallenge(t *testing.T) {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"),
@ -523,10 +538,11 @@ func TestHandler_GetChallenge(t *testing.T) {
},
"fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"),
@ -534,7 +550,7 @@ func TestHandler_GetChallenge(t *testing.T) {
},
"fail/db.GetChallenge-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -553,7 +569,7 @@ func TestHandler_GetChallenge(t *testing.T) {
},
"fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -572,7 +588,7 @@ func TestHandler_GetChallenge(t *testing.T) {
},
"fail/no-jwk": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -591,7 +607,7 @@ func TestHandler_GetChallenge(t *testing.T) {
},
"fail/nil-jwk": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, jwkContextKey, nil)
@ -611,7 +627,7 @@ func TestHandler_GetChallenge(t *testing.T) {
},
"fail/validate-challenge-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
@ -639,8 +655,8 @@ func TestHandler_GetChallenge(t *testing.T) {
return acme.NewErrorISE("force")
},
},
vco: &acme.ValidateChallengeOptions{
HTTPGet: func(string) (*http.Response, error) {
vc: &mockClient{
get: func(string) (*http.Response, error) {
return nil, errors.New("force")
},
},
@ -651,14 +667,13 @@ func TestHandler_GetChallenge(t *testing.T) {
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
_pub := _jwk.Public()
ctx = context.WithValue(ctx, jwkContextKey, &_pub)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
db: &acme.MockDB{
@ -690,8 +705,8 @@ func TestHandler_GetChallenge(t *testing.T) {
URL: u,
Error: acme.NewError(acme.ErrorConnectionType, "force"),
},
vco: &acme.ValidateChallengeOptions{
HTTPGet: func(string) (*http.Response, error) {
vc: &mockClient{
get: func(string) (*http.Response, error) {
return nil, errors.New("force")
},
},
@ -703,11 +718,11 @@ func TestHandler_GetChallenge(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.GetChallenge(w, req)
GetChallenge(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)

@ -9,7 +9,6 @@ import (
"net/url"
"strings"
"github.com/go-chi/chi"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
@ -31,39 +30,11 @@ func logNonce(w http.ResponseWriter, nonce string) {
}
}
// baseURLFromRequest determines the base URL which should be used for
// constructing link URLs in e.g. the ACME directory result by taking the
// request Host into consideration.
//
// If the Request.Host is an empty string, we return an empty string, to
// indicate that the configured URL values should be used instead. If this
// function returns a non-empty result, then this should be used in
// constructing ACME link URLs.
func baseURLFromRequest(r *http.Request) *url.URL {
// NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go
// for an implementation that allows HTTP requests using the x-forwarded-proto
// header.
if r.Host == "" {
return nil
}
return &url.URL{Scheme: "https", Host: r.Host}
}
// baseURLFromRequest is a middleware that extracts and caches the baseURL
// from the request.
// E.g. https://ca.smallstep.com/
func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r))
next(w, r.WithContext(ctx))
}
}
// addNonce is a middleware that adds a nonce to the response header.
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
func addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
nonce, err := h.db.CreateNonce(r.Context())
db := acme.MustDatabaseFromContext(r.Context())
nonce, err := db.CreateNonce(r.Context())
if err != nil {
render.Error(w, err)
return
@ -77,25 +48,31 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
// addDirLink is a middleware that adds a 'Link' response reader with the
// directory index url.
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
func addDirLink(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index"))
ctx := r.Context()
linker := acme.MustLinkerFromContext(ctx)
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
next(w, r)
}
}
// verifyContentType is a middleware that verifies that content type is
// application/jose+json.
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
func verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
var expected []string
p, err := provisionerFromContext(r.Context())
if err != nil {
render.Error(w, err)
return
}
u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")}
u := &url.URL{
Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""),
}
var expected []string
if strings.Contains(r.URL.String(), u.EscapedPath()) {
// GET /certificate requests allow a greater range of content types.
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
@ -117,7 +94,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
}
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
func parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
@ -142,17 +119,19 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
// The JWS Unprotected Header [RFC7515] MUST NOT be used
// The JWS Payload MUST NOT be detached
// The JWS Protected Header MUST include the following fields:
// * “alg” (Algorithm)
// * This field MUST NOT contain “none” or a Message Authentication Code
// (MAC) algorithm (e.g. one in which the algorithm registry description
// mentions MAC/HMAC).
// * “nonce” (defined in Section 6.5)
// * “url” (defined in Section 6.4)
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
// - “alg” (Algorithm).
// This field MUST NOT contain “none” or a Message Authentication Code
// (MAC) algorithm (e.g. one in which the algorithm registry description
// mentions MAC/HMAC).
// - “nonce” (defined in Section 6.5)
// - “url” (defined in Section 6.4)
// - Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
func validateJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
jws, err := jwsFromContext(r.Context())
db := acme.MustDatabaseFromContext(ctx)
jws, err := jwsFromContext(ctx)
if err != nil {
render.Error(w, err)
return
@ -202,7 +181,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
}
// Check the validity/freshness of the Nonce.
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
render.Error(w, err)
return
}
@ -235,10 +214,12 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
// extractJWK is a middleware that extracts the JWK from the JWS and saves it
// in the context. Make sure to parse and validate the JWS before running this
// middleware.
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
func extractJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
jws, err := jwsFromContext(r.Context())
db := acme.MustDatabaseFromContext(ctx)
jws, err := jwsFromContext(ctx)
if err != nil {
render.Error(w, err)
return
@ -264,7 +245,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
ctx = context.WithValue(ctx, jwkContextKey, jwk)
// Get Account OR continue to generate a new one OR continue Revoke with certificate private key
acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID)
acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID)
switch {
case errors.Is(err, acme.ErrNotFound):
// For NewAccount and Revoke requests ...
@ -283,63 +264,44 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
}
}
// lookupProvisioner loads the provisioner associated with the request.
// Responds 404 if the provisioner does not exist.
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return
}
p, err := h.ca.LoadProvisionerByName(name)
if err != nil {
render.Error(w, err)
return
}
acmeProv, ok := p.(*provisioner.ACME)
if !ok {
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return
}
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
next(w, r.WithContext(ctx))
}
}
// checkPrerequisites checks if all prerequisites for serving ACME
// are met by the CA configuration.
func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
func checkPrerequisites(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ok, err := h.prerequisitesChecker(ctx)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
return
}
if !ok {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return
// If the function is not set assume that all prerequisites are met.
checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx)
if ok {
ok, err := checkFunc(ctx)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
return
}
if !ok {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return
}
}
next(w, r.WithContext(ctx))
next(w, r)
}
}
// lookupJWK loads the JWK associated with the acme account referenced by the
// kid parameter of the signed payload.
// Make sure to parse and validate the JWS before running this middleware.
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
func lookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
jws, err := jwsFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) {
render.Error(w, acme.NewError(acme.ErrorMalformedType,
@ -349,7 +311,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
}
accID := strings.TrimPrefix(kid, kidPrefix)
acc, err := h.db.GetAccount(ctx, accID)
acc, err := db.GetAccount(ctx, accID)
switch {
case nosql.IsErrNotFound(err):
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
@ -372,7 +334,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
// extractOrLookupJWK forwards handling to either extractJWK or
// lookupJWK based on the presence of a JWK or a KID, respectively.
func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
func extractOrLookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
@ -385,13 +347,13 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
// and it can be used to check if a JWK exists. This flow is used when the ACME client
// signed the payload with a certificate private key.
if canExtractJWKFrom(jws) {
h.extractJWK(next)(w, r)
extractJWK(next)(w, r)
return
}
// default to looking up the JWK based on KeyID. This flow is used when the ACME client
// signed the payload with an account private key.
h.lookupJWK(next)(w, r)
lookupJWK(next)(w, r)
}
}
@ -408,7 +370,7 @@ func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
// Make sure to parse and validate the JWS before running this middleware.
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
@ -440,7 +402,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
}
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
func isPostAsGet(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
payload, err := payloadFromContext(r.Context())
if err != nil {
@ -462,16 +424,12 @@ type ContextKey string
const (
// accContextKey account key
accContextKey = ContextKey("acc")
// baseURLContextKey baseURL key
baseURLContextKey = ContextKey("baseURL")
// jwsContextKey jws key
jwsContextKey = ContextKey("jws")
// jwkContextKey jwk key
jwkContextKey = ContextKey("jwk")
// payloadContextKey payload key
payloadContextKey = ContextKey("payload")
// provisionerContextKey provisioner key
provisionerContextKey = ContextKey("provisioner")
)
// accountFromContext searches the context for an ACME account. Returns the
@ -484,15 +442,6 @@ func accountFromContext(ctx context.Context) (*acme.Account, error) {
return val, nil
}
// baseURLFromContext returns the baseURL if one is stored in the context.
func baseURLFromContext(ctx context.Context) *url.URL {
val, ok := ctx.Value(baseURLContextKey).(*url.URL)
if !ok || val == nil {
return nil
}
return val
}
// jwkFromContext searches the context for a JWK. Returns the JWK or an error.
func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
@ -514,29 +463,26 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
// provisionerFromContext searches the context for a provisioner. Returns the
// provisioner or an error.
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
val := ctx.Value(provisionerContextKey)
if val == nil {
p, ok := acme.ProvisionerFromContext(ctx)
if !ok || p == nil {
return nil, acme.NewErrorISE("provisioner expected in request context")
}
pval, ok := val.(acme.Provisioner)
if !ok || pval == nil {
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
}
return pval, nil
return p, nil
}
// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns
// pointer to an ACME provisioner or an error.
func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) {
prov, err := provisionerFromContext(ctx)
p, err := provisionerFromContext(ctx)
if err != nil {
return nil, err
}
acmeProv, ok := prov.(*provisioner.ACME)
if !ok || acmeProv == nil {
ap, ok := p.(*provisioner.ACME)
if !ok {
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
}
return acmeProv, nil
return ap, nil
}
// payloadFromContext searches the context for a payload. Returns the payload

@ -27,83 +27,18 @@ func testNext(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
}
func Test_baseURLFromRequest(t *testing.T) {
tests := []struct {
name string
targetURL string
expectedResult *url.URL
requestPreparer func(*http.Request)
}{
{
"HTTPS host pass-through failed.",
"https://my.dummy.host",
&url.URL{Scheme: "https", Host: "my.dummy.host"},
nil,
},
{
"Port pass-through failed",
"https://host.with.port:8080",
&url.URL{Scheme: "https", Host: "host.with.port:8080"},
nil,
},
{
"Explicit host from Request.Host was not used.",
"https://some.target.host:8080",
&url.URL{Scheme: "https", Host: "proxied.host"},
func(r *http.Request) {
r.Host = "proxied.host"
},
},
{
"Missing Request.Host value did not result in empty string result.",
"https://some.host",
nil,
func(r *http.Request) {
r.Host = ""
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest("GET", tc.targetURL, nil)
if tc.requestPreparer != nil {
tc.requestPreparer(request)
}
result := baseURLFromRequest(request)
if result == nil || tc.expectedResult == nil {
assert.Equals(t, result, tc.expectedResult)
} else if result.String() != tc.expectedResult.String() {
t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String())
}
})
}
}
func TestHandler_baseURLFromRequest(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest("GET", "/foo", nil)
req.Host = "test.ca.smallstep.com:8080"
w := httptest.NewRecorder()
next := func(w http.ResponseWriter, r *http.Request) {
bu := baseURLFromContext(r.Context())
if assert.NotNil(t, bu) {
assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080")
assert.Equals(t, bu.Scheme, "https")
func newBaseContext(ctx context.Context, args ...interface{}) context.Context {
for _, a := range args {
switch v := a.(type) {
case acme.DB:
ctx = acme.NewDatabaseContext(ctx, v)
case acme.Linker:
ctx = acme.NewLinkerContext(ctx, v)
case acme.PrerequisitesChecker:
ctx = acme.NewPrerequisitesCheckerContext(ctx, v)
}
}
h.baseURLFromRequest(next)(w, req)
req = httptest.NewRequest("GET", "/foo", nil)
req.Host = ""
next = func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, baseURLFromContext(r.Context()), nil)
}
h.baseURLFromRequest(next)(w, req)
return ctx
}
func TestHandler_addNonce(t *testing.T) {
@ -139,10 +74,10 @@ func TestHandler_addNonce(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", u, nil)
ctx := newBaseContext(context.Background(), tc.db)
req := httptest.NewRequest("GET", u, nil).WithContext(ctx)
w := httptest.NewRecorder()
h.addNonce(testNext)(w, req)
addNonce(testNext)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -175,17 +110,15 @@ func TestHandler_addDirLink(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
type test struct {
link string
linker Linker
statusCode int
ctx context.Context
err *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"ok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = acme.NewLinkerContext(ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
return test{
linker: NewLinker("dns", "acme"),
ctx: ctx,
link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName),
statusCode: 200,
@ -195,11 +128,10 @@ func TestHandler_addDirLink(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{linker: tc.linker}
req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.addDirLink(testNext)(w, req)
addDirLink(testNext)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -231,7 +163,6 @@ func TestHandler_verifyContentType(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
type test struct {
h Handler
ctx context.Context
contentType string
err *acme.Error
@ -241,9 +172,6 @@ func TestHandler_verifyContentType(t *testing.T) {
var tests = map[string]func(t *testing.T) test{
"fail/provisioner-not-set": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
},
url: u,
ctx: context.Background(),
contentType: "foo",
@ -253,11 +181,8 @@ func TestHandler_verifyContentType(t *testing.T) {
},
"fail/general-bad-content-type": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
},
url: u,
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
ctx: acme.NewProvisionerContext(context.Background(), prov),
contentType: "foo",
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"),
@ -265,10 +190,7 @@ func TestHandler_verifyContentType(t *testing.T) {
},
"fail/certificate-bad-content-type": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
ctx: acme.NewProvisionerContext(context.Background(), prov),
contentType: "foo",
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"),
@ -276,40 +198,28 @@ func TestHandler_verifyContentType(t *testing.T) {
},
"ok": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
ctx: acme.NewProvisionerContext(context.Background(), prov),
contentType: "application/jose+json",
statusCode: 200,
}
},
"ok/certificate/pkix-cert": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
ctx: acme.NewProvisionerContext(context.Background(), prov),
contentType: "application/pkix-cert",
statusCode: 200,
}
},
"ok/certificate/jose+json": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
ctx: acme.NewProvisionerContext(context.Background(), prov),
contentType: "application/jose+json",
statusCode: 200,
}
},
"ok/certificate/pkcs7-mime": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
ctx: acme.NewProvisionerContext(context.Background(), prov),
contentType: "application/pkcs7-mime",
statusCode: 200,
}
@ -326,7 +236,7 @@ func TestHandler_verifyContentType(t *testing.T) {
req = req.WithContext(tc.ctx)
req.Header.Add("Content-Type", tc.contentType)
w := httptest.NewRecorder()
tc.h.verifyContentType(testNext)(w, req)
verifyContentType(testNext)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -390,11 +300,11 @@ func TestHandler_isPostAsGet(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{}
// h := &Handler{}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.isPostAsGet(testNext)(w, req)
isPostAsGet(testNext)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -481,10 +391,10 @@ func TestHandler_parseJWS(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{}
// h := &Handler{}
req := httptest.NewRequest("GET", u, tc.body)
w := httptest.NewRecorder()
h.parseJWS(tc.next)(w, req)
parseJWS(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -679,11 +589,11 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{}
// h := &Handler{}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.verifyAndExtractJWSPayload(tc.next)(w, req)
verifyAndExtractJWSPayload(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -733,7 +643,7 @@ func TestHandler_lookupJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
type test struct {
linker Linker
linker acme.Linker
db acme.DB
ctx context.Context
next func(http.ResponseWriter, *http.Request)
@ -743,15 +653,19 @@ func TestHandler_lookupJWK(t *testing.T) {
var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"),
}
},
"fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, nil)
return test{
db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"),
@ -765,11 +679,11 @@ func TestHandler_lookupJWK(t *testing.T) {
assert.FatalError(t, err)
_jws, err := _signer.Sign([]byte("baz"))
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
linker: NewLinker("dns", "acme"),
db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix),
@ -789,22 +703,21 @@ func TestHandler_lookupJWK(t *testing.T) {
assert.FatalError(t, err)
_parsed, err := jose.ParseJWS(_raw)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _parsed)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
linker: NewLinker("dns", "acme"),
db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix),
}
},
"fail/account-not-found": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
linker: NewLinker("dns", "acme"),
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
assert.Equals(t, accID, accID)
@ -817,11 +730,10 @@ func TestHandler_lookupJWK(t *testing.T) {
}
},
"fail/GetAccount-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
linker: NewLinker("dns", "acme"),
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID)
@ -835,11 +747,10 @@ func TestHandler_lookupJWK(t *testing.T) {
},
"fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{Status: "deactivated"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
linker: NewLinker("dns", "acme"),
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID)
@ -853,11 +764,10 @@ func TestHandler_lookupJWK(t *testing.T) {
},
"ok": func(t *testing.T) test {
acc := &acme.Account{Status: "valid", Key: jwk}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
linker: NewLinker("dns", "acme"),
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID)
@ -881,11 +791,11 @@ func TestHandler_lookupJWK(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: tc.linker}
ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.lookupJWK(tc.next)(w, req)
lookupJWK(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -945,15 +855,17 @@ func TestHandler_extractJWK(t *testing.T) {
var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"),
}
},
"fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"),
@ -969,9 +881,10 @@ func TestHandler_extractJWK(t *testing.T) {
},
},
}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"),
@ -987,16 +900,17 @@ func TestHandler_extractJWK(t *testing.T) {
},
},
}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"),
}
},
"fail/GetAccountByKey-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
ctx: ctx,
@ -1012,7 +926,7 @@ func TestHandler_extractJWK(t *testing.T) {
},
"fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{Status: "deactivated"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
ctx: ctx,
@ -1028,7 +942,7 @@ func TestHandler_extractJWK(t *testing.T) {
},
"ok": func(t *testing.T) test {
acc := &acme.Account{Status: "valid"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
ctx: ctx,
@ -1051,7 +965,7 @@ func TestHandler_extractJWK(t *testing.T) {
}
},
"ok/no-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
ctx: ctx,
@ -1077,11 +991,11 @@ func TestHandler_extractJWK(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db}
ctx := newBaseContext(tc.ctx, tc.db)
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.extractJWK(tc.next)(w, req)
extractJWK(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1118,6 +1032,7 @@ func TestHandler_validateJWS(t *testing.T) {
var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test {
return test{
db: &acme.MockDB{},
ctx: context.Background(),
statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"),
@ -1125,6 +1040,7 @@ func TestHandler_validateJWS(t *testing.T) {
},
"fail/nil-jws": func(t *testing.T) test {
return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, nil),
statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"),
@ -1132,6 +1048,7 @@ func TestHandler_validateJWS(t *testing.T) {
},
"fail/no-signature": func(t *testing.T) test {
return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}),
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"),
@ -1145,6 +1062,7 @@ func TestHandler_validateJWS(t *testing.T) {
},
}
return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"),
@ -1157,6 +1075,7 @@ func TestHandler_validateJWS(t *testing.T) {
},
}
return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"),
@ -1169,6 +1088,7 @@ func TestHandler_validateJWS(t *testing.T) {
},
}
return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"),
@ -1181,6 +1101,7 @@ func TestHandler_validateJWS(t *testing.T) {
},
}
return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400,
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256),
@ -1444,11 +1365,11 @@ func TestHandler_validateJWS(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db}
ctx := newBaseContext(tc.ctx, tc.db)
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.validateJWS(tc.next)(w, req)
validateJWS(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1542,7 +1463,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
u := "https://ca.smallstep.com/acme/account"
type test struct {
db acme.DB
linker Linker
linker acme.Linker
statusCode int
ctx context.Context
err *acme.Error
@ -1570,7 +1491,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
return test{
linker: NewLinker("dns", "acme"),
linker: acme.NewLinker("dns", "acme"),
db: &acme.MockDB{
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
assert.Equals(t, kid, pub.KeyID)
@ -1606,11 +1527,10 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
linker: NewLinker("test.ca.smallstep.com", "acme"),
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
assert.Equals(t, accID, acc.ID)
@ -1628,11 +1548,11 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: tc.linker}
ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.extractOrLookupJWK(tc.next)(w, req)
extractOrLookupJWK(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1664,7 +1584,7 @@ func TestHandler_checkPrerequisites(t *testing.T) {
u := fmt.Sprintf("%s/acme/%s/account/1234",
baseURL, provName)
type test struct {
linker Linker
linker acme.Linker
ctx context.Context
prerequisitesChecker func(context.Context) (bool, error)
next func(http.ResponseWriter, *http.Request)
@ -1673,10 +1593,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
}
var tests = map[string]func(t *testing.T) test{
"fail/error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx := acme.NewProvisionerContext(context.Background(), prov)
return test{
linker: NewLinker("dns", "acme"),
linker: acme.NewLinker("dns", "acme"),
ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
next: func(w http.ResponseWriter, r *http.Request) {
@ -1687,10 +1606,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
}
},
"fail/prerequisites-nok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx := acme.NewProvisionerContext(context.Background(), prov)
return test{
linker: NewLinker("dns", "acme"),
linker: acme.NewLinker("dns", "acme"),
ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
next: func(w http.ResponseWriter, r *http.Request) {
@ -1701,10 +1619,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
}
},
"ok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx := acme.NewProvisionerContext(context.Background(), prov)
return test{
linker: NewLinker("dns", "acme"),
linker: acme.NewLinker("dns", "acme"),
ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
next: func(w http.ResponseWriter, r *http.Request) {
@ -1717,11 +1634,11 @@ func TestHandler_checkPrerequisites(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
ctx := acme.NewPrerequisitesCheckerContext(tc.ctx, tc.prerequisitesChecker)
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.checkPrerequisites(tc.next)(w, req)
checkPrerequisites(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)

@ -16,6 +16,8 @@ import (
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner"
)
// NewOrderRequest represents the body for a NewOrder request.
@ -37,6 +39,8 @@ func (n *NewOrderRequest) Validate() error {
if id.Type == acme.IP && net.ParseIP(id.Value) == nil {
return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value)
}
// TODO(hs): add some validations for DNS domains?
// TODO(hs): combine the errors from this with allow/deny policy, like example error in https://datatracker.ietf.org/doc/html/rfc8555#section-6.7.1
}
return nil
}
@ -50,7 +54,13 @@ type FinalizeRequest struct {
// Validate validates a finalize request body.
func (f *FinalizeRequest) Validate() error {
var err error
csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR)
// RFC 8555 isn't 100% conclusive about using raw base64-url encoding for the
// CSR specifically, instead of "normal" base64-url encoding (incl. padding).
// By trimming the padding from CSRs submitted by ACME clients that use
// base64-url encoding instead of raw base64-url encoding, these are also
// supported. This was reported in https://github.com/smallstep/certificates/issues/939
// to be the case for a Synology DSM NAS system.
csrBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(f.CSR, "="))
if err != nil {
return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr")
}
@ -68,8 +78,12 @@ var defaultOrderExpiry = time.Hour * 24
var defaultOrderBackdate = time.Minute
// NewOrder ACME api for creating a new order.
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
func NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ca := mustAuthority(ctx)
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
render.Error(w, err)
@ -85,6 +99,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
render.Error(w, err)
return
}
var nor NewOrderRequest
if err := json.Unmarshal(payload.value, &nor); err != nil {
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
@ -97,6 +112,48 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
return
}
// TODO(hs): gather all errors, so that we can build one response with ACME subproblems
// include the nor.Validate() error here too, like in the example in the ACME RFC?
acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
var eak *acme.ExternalAccountKey
if acmeProv.RequireEAB {
if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key"))
return
}
}
acmePolicy, err := newACMEPolicyEngine(eak)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating ACME policy engine"))
return
}
for _, identifier := range nor.Identifiers {
// evaluate the ACME account level policy
if err = isIdentifierAllowed(acmePolicy, identifier); err != nil {
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return
}
// evaluate the provisioner level policy
orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value}
if err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier); err != nil {
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return
}
// evaluate the authority level policy
if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return
}
}
now := clock.Now()
// New order.
o := &acme.Order{
@ -117,7 +174,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
ExpiresAt: o.ExpiresAt,
Status: acme.StatusPending,
}
if err := h.newAuthorization(ctx, az); err != nil {
if err := newAuthorization(ctx, az); err != nil {
render.Error(w, err)
return
}
@ -136,18 +193,32 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
}
if err := h.db.CreateOrder(ctx, o); err != nil {
if err := db.CreateOrder(ctx, o); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating order"))
return
}
h.linker.LinkOrder(ctx, o)
linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSONStatus(w, o, http.StatusCreated)
}
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
func isIdentifierAllowed(acmePolicy policy.X509Policy, identifier acme.Identifier) error {
if acmePolicy == nil {
return nil
}
return acmePolicy.AreSANsAllowed([]string{identifier.Value})
}
func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error) {
if eak == nil {
return nil, nil
}
return policy.NewX509PolicyEngine(eak.Policy)
}
func newAuthorization(ctx context.Context, az *acme.Authorization) error {
if strings.HasPrefix(az.Identifier.Value, "*.") {
az.Wildcard = true
az.Identifier = acme.Identifier{
@ -163,6 +234,8 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization)
if err != nil {
return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
}
db := acme.MustDatabaseFromContext(ctx)
az.Challenges = make([]*acme.Challenge, len(chTypes))
for i, typ := range chTypes {
ch := &acme.Challenge{
@ -172,20 +245,23 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization)
Token: az.Token,
Status: acme.StatusPending,
}
if err := h.db.CreateChallenge(ctx, ch); err != nil {
if err := db.CreateChallenge(ctx, ch); err != nil {
return acme.WrapErrorISE(err, "error creating challenge")
}
az.Challenges[i] = ch
}
if err = h.db.CreateAuthorization(ctx, az); err != nil {
if err = db.CreateAuthorization(ctx, az); err != nil {
return acme.WrapErrorISE(err, "error creating authorization")
}
return nil
}
// GetOrder ACME api for retrieving an order.
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
func GetOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
render.Error(w, err)
@ -196,7 +272,8 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
render.Error(w, err)
return
}
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return
@ -211,20 +288,23 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return
}
if err = o.UpdateStatus(ctx, h.db); err != nil {
if err = o.UpdateStatus(ctx, db); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
return
}
h.linker.LinkOrder(ctx, o)
linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSON(w, o)
}
// FinalizeOrder attemptst to finalize an order and create a certificate.
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
// FinalizeOrder attempts to finalize an order and create a certificate.
func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx)
if err != nil {
render.Error(w, err)
@ -251,7 +331,7 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
return
}
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return
@ -266,14 +346,16 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return
}
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
ca := mustAuthority(ctx)
if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
return
}
h.linker.LinkOrder(ctx, o)
linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSON(w, o)
}

@ -16,9 +16,13 @@ import (
"github.com/go-chi/chi"
"github.com/pkg/errors"
"go.step.sm/crypto/pemutil"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"go.step.sm/crypto/pemutil"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner"
)
func TestNewOrderRequest_Validate(t *testing.T) {
@ -206,6 +210,13 @@ func TestFinalizeRequestValidate(t *testing.T) {
},
}
},
"ok/padding": func(t *testing.T) test {
return test{
fr: &FinalizeRequest{
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw) + "==", // add intentional padding
},
}
},
}
for name, run := range tests {
tc := run(t)
@ -276,15 +287,17 @@ func TestHandler_GetOrder(t *testing.T) {
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -294,6 +307,7 @@ func TestHandler_GetOrder(t *testing.T) {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"),
@ -301,9 +315,10 @@ func TestHandler_GetOrder(t *testing.T) {
},
"fail/nil-provisioner": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
ctx := acme.NewProvisionerContext(context.Background(), nil)
ctx = context.WithValue(ctx, accContextKey, acc)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"),
@ -311,7 +326,7 @@ func TestHandler_GetOrder(t *testing.T) {
},
"fail/db.GetOrder-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
@ -325,7 +340,7 @@ func TestHandler_GetOrder(t *testing.T) {
},
"fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
@ -341,7 +356,7 @@ func TestHandler_GetOrder(t *testing.T) {
},
"fail/provisioner-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
@ -357,7 +372,7 @@ func TestHandler_GetOrder(t *testing.T) {
},
"fail/order-update-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
@ -381,10 +396,9 @@ func TestHandler_GetOrder(t *testing.T) {
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
db: &acme.MockDB{
MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
@ -421,11 +435,11 @@ func TestHandler_GetOrder(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.GetOrder(w, req)
GetOrder(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -636,8 +650,8 @@ func TestHandler_newAuthorization(t *testing.T) {
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
h := &Handler{db: tc.db}
if err := h.newAuthorization(context.Background(), tc.az); err != nil {
ctx := newBaseContext(context.Background(), tc.db)
if err := newAuthorization(ctx, tc.az); err != nil {
if assert.NotNil(t, tc.err) {
switch k := err.(type) {
case *acme.Error:
@ -667,6 +681,7 @@ func TestHandler_NewOrder(t *testing.T) {
baseURL.String(), escProvName)
type test struct {
ca acme.CertificateAuthority
db acme.DB
ctx context.Context
nor *NewOrderRequest
@ -677,15 +692,17 @@ func TestHandler_NewOrder(t *testing.T) {
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -695,6 +712,7 @@ func TestHandler_NewOrder(t *testing.T) {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"),
@ -702,9 +720,10 @@ func TestHandler_NewOrder(t *testing.T) {
},
"fail/nil-provisioner": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"),
@ -713,8 +732,9 @@ func TestHandler_NewOrder(t *testing.T) {
"fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("payload does not exist"),
@ -722,21 +742,23 @@ func TestHandler_NewOrder(t *testing.T) {
},
"fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("paylod does not exist"),
err: acme.NewErrorISE("payload does not exist"),
}
},
"fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"),
@ -747,15 +769,232 @@ func TestHandler_NewOrder(t *testing.T) {
fr := &NewOrderRequest{}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"),
}
},
"fail/acmeProvisionerFromContext-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), &acme.MockProvisioner{})
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, errors.New("force")
},
},
err: acme.NewErrorISE("error retrieving external account binding key: force"),
}
},
"fail/db.GetExternalAccountKeyByAccountID-error": func(t *testing.T) test {
acmeProv := newACMEProv(t)
acmeProv.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, errors.New("force")
},
},
err: acme.NewErrorISE("error retrieving external account binding key: force"),
}
},
"fail/newACMEPolicyEngine-error": func(t *testing.T) test {
acmeProv := newACMEProv(t)
acmeProv.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return &acme.ExternalAccountKey{
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"**.local"},
},
},
},
}, nil
},
},
err: acme.NewErrorISE("error creating ACME policy engine"),
}
},
"fail/isIdentifierAllowed-error": func(t *testing.T) test {
acmeProv := newACMEProv(t)
acmeProv.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return &acme.ExternalAccountKey{
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.local"},
},
},
},
}, nil
},
},
err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"),
}
},
"fail/prov.AuthorizeOrderIdentifier-error": func(t *testing.T) test {
options := &provisioner.Options{
X509: &provisioner.X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"*.local"},
},
},
}
provWithPolicy := newACMEProvWithOptions(t, options)
provWithPolicy.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return &acme.ExternalAccountKey{
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.internal"},
},
},
},
}, nil
},
},
err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"),
}
},
"fail/ca.AreSANsAllowed-error": func(t *testing.T) test {
options := &provisioner.Options{
X509: &provisioner.X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"*.internal"},
},
},
}
provWithPolicy := newACMEProvWithOptions(t, options)
provWithPolicy.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
ca: &mockCA{
MockAreSANsallowed: func(ctx context.Context, sans []string) error {
return errors.New("force: not authorized by authority")
},
},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return &acme.ExternalAccountKey{
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.internal"},
},
},
},
}, nil
},
},
err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"),
}
},
"fail/error-h.newAuthorization": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
@ -765,12 +1004,13 @@ func TestHandler_NewOrder(t *testing.T) {
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
assert.Equals(t, ch.AccountID, "accID")
@ -780,6 +1020,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, ch.Value, "zap.internal")
return errors.New("force")
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
err: acme.NewErrorISE("error creating challenge: force"),
}
@ -793,7 +1038,7 @@ func TestHandler_NewOrder(t *testing.T) {
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
var (
@ -804,6 +1049,7 @@ func TestHandler_NewOrder(t *testing.T) {
return test{
ctx: ctx,
statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count {
@ -849,6 +1095,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return errors.New("force")
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
err: acme.NewErrorISE("error creating order: force"),
}
@ -863,10 +1114,9 @@ func TestHandler_NewOrder(t *testing.T) {
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var (
ch1, ch2, ch3, ch4 **acme.Challenge
az1ID, az2ID *string
@ -876,6 +1126,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx,
statusCode: 201,
nor: nor,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch chCount {
@ -945,6 +1196,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID, *az2ID})
return nil
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
vr: func(t *testing.T, o *acme.Order) {
now := clock.Now()
@ -978,10 +1234,9 @@ func TestHandler_NewOrder(t *testing.T) {
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var (
ch1, ch2, ch3 **acme.Challenge
az1ID *string
@ -991,6 +1246,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx,
statusCode: 201,
nor: nor,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count {
@ -1037,6 +1293,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
vr: func(t *testing.T, o *acme.Order) {
now := clock.Now()
@ -1070,10 +1331,9 @@ func TestHandler_NewOrder(t *testing.T) {
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var (
ch1, ch2, ch3 **acme.Challenge
az1ID *string
@ -1083,6 +1343,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx,
statusCode: 201,
nor: nor,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count {
@ -1129,6 +1390,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
vr: func(t *testing.T, o *acme.Order) {
now := clock.Now()
@ -1161,10 +1427,9 @@ func TestHandler_NewOrder(t *testing.T) {
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var (
ch1, ch2, ch3 **acme.Challenge
az1ID *string
@ -1174,6 +1439,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx,
statusCode: 201,
nor: nor,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count {
@ -1220,6 +1486,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
vr: func(t *testing.T, o *acme.Order) {
testBufferDur := 5 * time.Second
@ -1253,10 +1524,109 @@ func TestHandler_NewOrder(t *testing.T) {
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
var (
ch1, ch2, ch3 **acme.Challenge
az1ID *string
count = 0
)
return test{
ctx: ctx,
statusCode: 201,
nor: nor,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count {
case 0:
ch.ID = "dns"
assert.Equals(t, ch.Type, acme.DNS01)
ch1 = &ch
case 1:
ch.ID = "http"
assert.Equals(t, ch.Type, acme.HTTP01)
ch2 = &ch
case 2:
ch.ID = "tls"
assert.Equals(t, ch.Type, acme.TLSALPN01)
ch3 = &ch
default:
assert.FatalError(t, errors.New("test logic error"))
return errors.New("force")
}
count++
assert.Equals(t, ch.AccountID, "accID")
assert.NotEquals(t, ch.Token, "")
assert.Equals(t, ch.Status, acme.StatusPending)
assert.Equals(t, ch.Value, "zap.internal")
return nil
},
MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error {
az.ID = "az1ID"
az1ID = &az.ID
assert.Equals(t, az.AccountID, "accID")
assert.NotEquals(t, az.Token, "")
assert.Equals(t, az.Status, acme.StatusPending)
assert.Equals(t, az.Identifier, nor.Identifiers[0])
assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3})
assert.Equals(t, az.Wildcard, false)
return nil
},
MockCreateOrder: func(ctx context.Context, o *acme.Order) error {
o.ID = "ordID"
assert.Equals(t, o.AccountID, "accID")
assert.Equals(t, o.ProvisionerID, prov.GetID())
assert.Equals(t, o.Status, acme.StatusPending)
assert.Equals(t, o.Identifiers, nor.Identifiers)
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
vr: func(t *testing.T, o *acme.Order) {
testBufferDur := 5 * time.Second
orderExpiry := now.Add(defaultOrderExpiry)
assert.Equals(t, o.ID, "ordID")
assert.Equals(t, o.Status, acme.StatusPending)
assert.Equals(t, o.Identifiers, nor.Identifiers)
assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)})
assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf))
assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf))
assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry))
assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry))
},
}
},
"ok/default-naf-nbf-with-policy": func(t *testing.T) test {
options := &provisioner.Options{
X509: &provisioner.X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"*.internal"},
},
},
}
provWithPolicy := newACMEProvWithOptions(t, options)
provWithPolicy.RequireEAB = true
acc := &acme.Account{ID: "accID"}
nor := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var (
ch1, ch2, ch3 **acme.Challenge
az1ID *string
@ -1266,6 +1636,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx,
statusCode: 201,
nor: nor,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count {
@ -1312,10 +1683,18 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
vr: func(t *testing.T, o *acme.Order) {
now := clock.Now()
testBufferDur := 5 * time.Second
orderExpiry := now.Add(defaultOrderExpiry)
expNbf := now.Add(-defaultOrderBackdate)
expNaf := now.Add(prov.DefaultTLSCertDuration())
assert.Equals(t, o.ID, "ordID")
assert.Equals(t, o.Status, acme.StatusPending)
@ -1334,11 +1713,12 @@ func TestHandler_NewOrder(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
mockMustAuthority(t, tc.ca)
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.NewOrder(w, req)
NewOrder(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1371,6 +1751,7 @@ func TestHandler_NewOrder(t *testing.T) {
}
func TestHandler_FinalizeOrder(t *testing.T) {
mockMustAuthority(t, &mockCA{})
prov := newProv()
escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
@ -1429,15 +1810,17 @@ func TestHandler_FinalizeOrder(t *testing.T) {
var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
}
},
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -1447,6 +1830,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"),
@ -1454,9 +1838,10 @@ func TestHandler_FinalizeOrder(t *testing.T) {
},
"fail/nil-provisioner": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"),
@ -1465,8 +1850,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
"fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("payload does not exist"),
@ -1474,21 +1860,23 @@ func TestHandler_FinalizeOrder(t *testing.T) {
},
"fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("paylod does not exist"),
err: acme.NewErrorISE("payload does not exist"),
}
},
"fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"),
@ -1499,10 +1887,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
fr := &FinalizeRequest{}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"),
@ -1511,7 +1900,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
"fail/db.GetOrder-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1526,7 +1915,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
},
"fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1543,7 +1932,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
},
"fail/provisioner-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1560,7 +1949,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
},
"fail/order-finalize-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1585,10 +1974,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
},
"ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
db: &acme.MockDB{
@ -1624,11 +2012,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.FinalizeOrder(w, req)
FinalizeOrder(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)

@ -26,9 +26,11 @@ type revokePayload struct {
}
// RevokeCert attempts to revoke a certificate.
func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
func RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
jws, err := jwsFromContext(ctx)
if err != nil {
render.Error(w, err)
@ -69,7 +71,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
}
serial := certToBeRevoked.SerialNumber.String()
dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
dbCert, err := db.GetCertificateBySerial(ctx, serial)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
return
@ -87,7 +89,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
render.Error(w, err)
return
}
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
if acmeErr != nil {
render.Error(w, acmeErr)
return
@ -103,7 +105,8 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
}
}
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
ca := mustAuthority(ctx)
hasBeenRevokedBefore, err := ca.IsRevoked(serial)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
return
@ -130,14 +133,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
}
options := revokeOptions(serial, certToBeRevoked, reasonCode)
err = h.ca.Revoke(ctx, options)
err = ca.Revoke(ctx, options)
if err != nil {
render.Error(w, wrapRevokeErr(err))
return
}
logRevoke(w, options)
w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index"))
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
w.Write(nil)
}
@ -148,7 +151,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations
// that are stored for the ACME Account. If these sets match, the Account is considered authorized
// to revoke the certificate. If this check fails, the client will receive an unauthorized error.
func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
func isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
if !account.IsValid() {
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
}

@ -24,14 +24,16 @@ import (
"github.com/go-chi/chi"
"github.com/google/go-cmp/cmp"
"github.com/pkg/errors"
"golang.org/x/crypto/ocsp"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ocsp"
)
// v is a utility function to return the pointer to an integer
@ -274,14 +276,22 @@ func jwsFinal(sha crypto.Hash, sig []byte, phead, payload string) ([]byte, error
}
type mockCA struct {
MockIsRevoked func(sn string) (bool, error)
MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error
MockIsRevoked func(sn string) (bool, error)
MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error
MockAreSANsallowed func(ctx context.Context, sans []string) error
}
func (m *mockCA) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return nil, nil
}
func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error {
if m.MockAreSANsallowed != nil {
return m.MockAreSANsallowed(ctx, sans)
}
return nil
}
func (m *mockCA) IsRevoked(sn string) (bool, error) {
if m.MockIsRevoked != nil {
return m.MockIsRevoked(sn)
@ -511,6 +521,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/no-jws": func(t *testing.T) test {
ctx := context.Background()
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"),
@ -519,6 +530,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"),
@ -527,6 +539,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/no-provisioner": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"),
@ -534,8 +547,9 @@ func TestHandler_RevokeCert(t *testing.T) {
},
"fail/nil-provisioner": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, nil)
ctx = acme.NewProvisionerContext(ctx, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"),
@ -543,8 +557,9 @@ func TestHandler_RevokeCert(t *testing.T) {
},
"fail/no-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("payload does not exist"),
@ -552,9 +567,10 @@ func TestHandler_RevokeCert(t *testing.T) {
},
"fail/nil-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("payload does not exist"),
@ -563,9 +579,10 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/unmarshal-payload": func(t *testing.T) test {
malformedPayload := []byte(`{"payload":malformed?}`)
ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload})
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("error unmarshaling payload"),
@ -577,10 +594,11 @@ func TestHandler_RevokeCert(t *testing.T) {
}
wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: &acme.Error{
@ -596,10 +614,11 @@ func TestHandler_RevokeCert(t *testing.T) {
}
emptyPayloadBytes, err := json.Marshal(emptyPayload)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
return test{
db: &acme.MockDB{},
ctx: ctx,
statusCode: 400,
err: &acme.Error{
@ -610,7 +629,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}
},
"fail/db.GetCertificateBySerial": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
db := &acme.MockDB{
@ -628,7 +647,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/different-certificate-contents": func(t *testing.T) test {
aDifferentCert, _, err := generateCertKeyPair()
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
db := &acme.MockDB{
@ -647,7 +666,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}
},
"fail/no-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
db := &acme.MockDB{
@ -666,7 +685,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}
},
"fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, accContextKey, nil)
@ -687,11 +706,10 @@ func TestHandler_RevokeCert(t *testing.T) {
},
"fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -717,11 +735,10 @@ func TestHandler_RevokeCert(t *testing.T) {
},
"fail/account-not-authorized": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -771,10 +788,9 @@ func TestHandler_RevokeCert(t *testing.T) {
assert.FatalError(t, err)
unauthorizedPayloadBytes, err := json.Marshal(jwsPayload)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -798,11 +814,10 @@ func TestHandler_RevokeCert(t *testing.T) {
},
"fail/certificate-revoked-check-fails": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -832,7 +847,7 @@ func TestHandler_RevokeCert(t *testing.T) {
},
"fail/certificate-already-revoked": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -870,7 +885,7 @@ func TestHandler_RevokeCert(t *testing.T) {
invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload)
assert.FatalError(t, err)
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -908,7 +923,7 @@ func TestHandler_RevokeCert(t *testing.T) {
},
}
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, mockACMEProv)
ctx := acme.NewProvisionerContext(context.Background(), mockACMEProv)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -940,7 +955,7 @@ func TestHandler_RevokeCert(t *testing.T) {
},
"fail/ca.Revoke": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -972,7 +987,7 @@ func TestHandler_RevokeCert(t *testing.T) {
},
"fail/ca.Revoke-already-revoked": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -1003,11 +1018,10 @@ func TestHandler_RevokeCert(t *testing.T) {
},
"ok/using-account-key": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -1031,10 +1045,9 @@ func TestHandler_RevokeCert(t *testing.T) {
assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(jwsBytes))
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -1057,11 +1070,12 @@ func TestHandler_RevokeCert(t *testing.T) {
for name, setup := range tests {
tc := setup(t)
t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca}
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
mockMustAuthority(t, tc.ca)
req := httptest.NewRequest("POST", revokeURL, nil)
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.RevokeCert(w, req)
RevokeCert(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1198,8 +1212,8 @@ func TestHandler_isAccountAuthorized(t *testing.T) {
for name, setup := range tests {
tc := setup(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db}
acmeErr := h.isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
// h := &Handler{db: tc.db}
acmeErr := isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
expectError := tc.err != nil
gotError := acmeErr != nil

@ -14,7 +14,6 @@ import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"reflect"
"strings"
@ -61,27 +60,28 @@ func (ch *Challenge) ToLog() (interface{}, error) {
// type using the DB interface.
// satisfactorily validated, the 'status' and 'validated' attributes are
// updated.
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey) error {
// If already valid or invalid then return without performing validation.
if ch.Status != StatusPending {
return nil
}
switch ch.Type {
case HTTP01:
return http01Validate(ctx, ch, db, jwk, vo)
return http01Validate(ctx, ch, db, jwk)
case DNS01:
return dns01Validate(ctx, ch, db, jwk, vo)
return dns01Validate(ctx, ch, db, jwk)
case TLSALPN01:
return tlsalpn01Validate(ctx, ch, db, jwk, vo)
return tlsalpn01Validate(ctx, ch, db, jwk)
default:
return NewErrorISE("unexpected challenge type '%s'", ch.Type)
}
}
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
resp, err := vo.HTTPGet(u.String())
vc := MustClientFromContext(ctx)
resp, err := vc.Get(u.String())
if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
"error doing http GET for url %s", u))
@ -119,6 +119,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
return nil
}
// http01ChallengeHost checks if a Challenge value is an IPv6 address
// and adds square brackets if that's the case, so that it can be used
// as a hostname. Returns the original Challenge value as the host to
// use in other cases.
func http01ChallengeHost(value string) string {
if ip := net.ParseIP(value); ip != nil && ip.To4() == nil {
value = "[" + value + "]"
}
return value
}
func tlsAlert(err error) uint8 {
var opErr *net.OpError
if errors.As(err, &opErr) {
@ -130,7 +141,7 @@ func tlsAlert(err error) uint8 {
return 0
}
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
config := &tls.Config{
NextProtos: []string{"acme-tls/1"},
// https://tools.ietf.org/html/rfc8737#section-4
@ -143,7 +154,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
hostPort := net.JoinHostPort(ch.Value, "443")
conn, err := vo.TLSDial("tcp", hostPort, config)
vc := MustClientFromContext(ctx)
conn, err := vc.TLSDial("tcp", hostPort, config)
if err != nil {
// With Go 1.17+ tls.Dial fails if there's no overlap between configured
// client and server protocols. When this happens the connection is
@ -242,14 +254,15 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
"incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
}
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
// Normalize domain for wildcard DNS names
// This is done to avoid making TXT lookups for domains like
// _acme-challenge.*.example.com
// Instead perform txt lookup for _acme-challenge.example.com
domain := strings.TrimPrefix(ch.Value, "*.")
txtRecords, err := vo.LookupTxt("_acme-challenge." + domain)
vc := MustClientFromContext(ctx)
txtRecords, err := vc.LookupTxt("_acme-challenge." + domain)
if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err,
"error looking up TXT records for domain %s", domain))
@ -365,14 +378,3 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err
}
return nil
}
type httpGetter func(string) (*http.Response, error)
type lookupTxt func(string) ([]string, error)
type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
// ValidateChallengeOptions are ACME challenge validator functions.
type ValidateChallengeOptions struct {
HTTPGet httpGetter
LookupTxt lookupTxt
TLSDial tlsDialer
}

@ -13,6 +13,7 @@ import (
"encoding/asn1"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"math/big"
@ -23,11 +24,23 @@ import (
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
)
type mockClient struct {
get func(url string) (*http.Response, error)
lookupTxt func(name string) ([]string, error)
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
}
func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) }
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
return m.tlsDial(network, addr, config)
}
func Test_storeError(t *testing.T) {
type test struct {
ch *Challenge
@ -228,7 +241,7 @@ func TestKeyAuthorization(t *testing.T) {
func TestChallenge_Validate(t *testing.T) {
type test struct {
ch *Challenge
vo *ValidateChallengeOptions
vc Client
jwk *jose.JSONWebKey
db DB
srv *httptest.Server
@ -272,8 +285,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) {
vc: &mockClient{
get: func(url string) (*http.Response, error) {
return nil, errors.New("force")
},
},
@ -308,8 +321,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) {
vc: &mockClient{
get: func(url string) (*http.Response, error) {
return nil, errors.New("force")
},
},
@ -343,8 +356,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
LookupTxt: func(url string) ([]string, error) {
vc: &mockClient{
lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force")
},
},
@ -380,8 +393,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
LookupTxt: func(url string) ([]string, error) {
vc: &mockClient{
lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force")
},
},
@ -415,8 +428,8 @@ func TestChallenge_Validate(t *testing.T) {
}
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
vc: &mockClient{
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return nil, errors.New("force")
},
},
@ -465,8 +478,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -492,7 +505,8 @@ func TestChallenge_Validate(t *testing.T) {
defer tc.srv.Close()
}
if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil {
ctx := NewClientContext(context.Background(), tc.vc)
if err := tc.ch.Validate(ctx, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) {
switch k := err.(type) {
case *Error:
@ -523,7 +537,7 @@ func (errReader) Close() error {
func TestHTTP01Validate(t *testing.T) {
type test struct {
vo *ValidateChallengeOptions
vc Client
ch *Challenge
jwk *jose.JSONWebKey
db DB
@ -540,8 +554,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) {
vc: &mockClient{
get: func(url string) (*http.Response, error) {
return nil, errors.New("force")
},
},
@ -574,8 +588,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) {
vc: &mockClient{
get: func(url string) (*http.Response, error) {
return nil, errors.New("force")
},
},
@ -607,8 +621,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) {
vc: &mockClient{
get: func(url string) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadRequest,
Body: errReader(0),
@ -644,8 +658,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) {
vc: &mockClient{
get: func(url string) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadRequest,
Body: errReader(0),
@ -680,8 +694,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) {
vc: &mockClient{
get: func(url string) (*http.Response, error) {
return &http.Response{
Body: errReader(0),
}, nil
@ -703,8 +717,8 @@ func TestHTTP01Validate(t *testing.T) {
jwk.Key = "foo"
return test{
ch: ch,
vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) {
vc: &mockClient{
get: func(url string) (*http.Response, error) {
return &http.Response{
Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil
@ -729,8 +743,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err)
return test{
ch: ch,
vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) {
vc: &mockClient{
get: func(url string) (*http.Response, error) {
return &http.Response{
Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil
@ -771,8 +785,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err)
return test{
ch: ch,
vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) {
vc: &mockClient{
get: func(url string) (*http.Response, error) {
return &http.Response{
Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil
@ -814,8 +828,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err)
return test{
ch: ch,
vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) {
vc: &mockClient{
get: func(url string) (*http.Response, error) {
return &http.Response{
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil
@ -856,8 +870,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err)
return test{
ch: ch,
vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) {
vc: &mockClient{
get: func(url string) (*http.Response, error) {
return &http.Response{
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil
@ -886,7 +900,8 @@ func TestHTTP01Validate(t *testing.T) {
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
ctx := NewClientContext(context.Background(), tc.vc)
if err := http01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) {
switch k := err.(type) {
case *Error:
@ -910,7 +925,7 @@ func TestDNS01Validate(t *testing.T) {
fulldomain := "*.zap.internal"
domain := strings.TrimPrefix(fulldomain, "*.")
type test struct {
vo *ValidateChallengeOptions
vc Client
ch *Challenge
jwk *jose.JSONWebKey
db DB
@ -927,8 +942,8 @@ func TestDNS01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
LookupTxt: func(url string) ([]string, error) {
vc: &mockClient{
lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force")
},
},
@ -962,8 +977,8 @@ func TestDNS01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
LookupTxt: func(url string) ([]string, error) {
vc: &mockClient{
lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force")
},
},
@ -1000,8 +1015,8 @@ func TestDNS01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
LookupTxt: func(url string) ([]string, error) {
vc: &mockClient{
lookupTxt: func(url string) ([]string, error) {
return []string{"foo"}, nil
},
},
@ -1025,8 +1040,8 @@ func TestDNS01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
LookupTxt: func(url string) ([]string, error) {
vc: &mockClient{
lookupTxt: func(url string) ([]string, error) {
return []string{"foo", "bar"}, nil
},
},
@ -1067,8 +1082,8 @@ func TestDNS01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
LookupTxt: func(url string) ([]string, error) {
vc: &mockClient{
lookupTxt: func(url string) ([]string, error) {
return []string{"foo", "bar"}, nil
},
},
@ -1110,8 +1125,8 @@ func TestDNS01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
LookupTxt: func(url string) ([]string, error) {
vc: &mockClient{
lookupTxt: func(url string) ([]string, error) {
return []string{"foo", expected}, nil
},
},
@ -1155,8 +1170,8 @@ func TestDNS01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
LookupTxt: func(url string) ([]string, error) {
vc: &mockClient{
lookupTxt: func(url string) ([]string, error) {
return []string{"foo", expected}, nil
},
},
@ -1185,7 +1200,8 @@ func TestDNS01Validate(t *testing.T) {
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
ctx := NewClientContext(context.Background(), tc.vc)
if err := dns01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) {
switch k := err.(type) {
case *Error:
@ -1205,6 +1221,8 @@ func TestDNS01Validate(t *testing.T) {
}
}
type tlsDialer func(network, addr string, config *tls.Config) (conn *tls.Conn, err error)
func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) {
srv := httptest.NewUnstartedServer(http.NewServeMux())
@ -1308,7 +1326,7 @@ func TestTLSALPN01Validate(t *testing.T) {
}
}
type test struct {
vo *ValidateChallengeOptions
vc Client
ch *Challenge
jwk *jose.JSONWebKey
db DB
@ -1320,8 +1338,8 @@ func TestTLSALPN01Validate(t *testing.T) {
ch := makeTLSCh()
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
vc: &mockClient{
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return nil, errors.New("force")
},
},
@ -1350,8 +1368,8 @@ func TestTLSALPN01Validate(t *testing.T) {
ch := makeTLSCh()
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
vc: &mockClient{
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return nil, errors.New("force")
},
},
@ -1383,8 +1401,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1412,8 +1430,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
vc: &mockClient{
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.Client(&noopConn{}, config), nil
},
},
@ -1442,8 +1460,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
vc: &mockClient{
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.Client(&noopConn{}, config), nil
},
},
@ -1478,8 +1496,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
vc: &mockClient{
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
},
},
@ -1515,8 +1533,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
vc: &mockClient{
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
},
},
@ -1561,8 +1579,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1604,8 +1622,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1648,8 +1666,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1691,8 +1709,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1735,8 +1753,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
srv: srv,
jwk: jwk,
@ -1757,8 +1775,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1796,8 +1814,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1840,8 +1858,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1883,8 +1901,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1923,8 +1941,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1962,8 +1980,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2007,8 +2025,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2053,8 +2071,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2099,8 +2117,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2143,8 +2161,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2188,8 +2206,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2225,8 +2243,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{
ch: ch,
vo: &ValidateChallengeOptions{
TLSDial: tlsDial,
vc: &mockClient{
tlsDial: tlsDial,
},
db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2252,7 +2270,8 @@ func TestTLSALPN01Validate(t *testing.T) {
defer tc.srv.Close()
}
if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
ctx := NewClientContext(context.Background(), tc.vc)
if err := tlsalpn01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) {
switch k := err.(type) {
case *Error:
@ -2350,3 +2369,34 @@ func Test_serverName(t *testing.T) {
})
}
}
func Test_http01ChallengeHost(t *testing.T) {
tests := []struct {
name string
value string
want string
}{
{
name: "dns",
value: "www.example.com",
want: "www.example.com",
},
{
name: "ipv4",
value: "127.0.0.1",
want: "127.0.0.1",
},
{
name: "ipv6",
value: "::1",
want: "[::1]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := http01ChallengeHost(tt.value); got != tt.want {
t.Errorf("http01ChallengeHost() = %v, want %v", got, tt.want)
}
})
}
}

@ -0,0 +1,79 @@
package acme
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// Client is the interface used to verify ACME challenges.
type Client interface {
// Get issues an HTTP GET to the specified URL.
Get(url string) (*http.Response, error)
// LookupTXT returns the DNS TXT records for the given domain name.
LookupTxt(name string) ([]string, error)
// TLSDial connects to the given network address using net.Dialer and then
// initiates a TLS handshake, returning the resulting TLS connection.
TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error)
}
type clientKey struct{}
// NewClientContext adds the given client to the context.
func NewClientContext(ctx context.Context, c Client) context.Context {
return context.WithValue(ctx, clientKey{}, c)
}
// ClientFromContext returns the current client from the given context.
func ClientFromContext(ctx context.Context) (c Client, ok bool) {
c, ok = ctx.Value(clientKey{}).(Client)
return
}
// MustClientFromContext returns the current client from the given context. It will
// return a new instance of the client if it does not exist.
func MustClientFromContext(ctx context.Context) Client {
c, ok := ClientFromContext(ctx)
if !ok {
return NewClient()
}
return c
}
type client struct {
http *http.Client
dialer *net.Dialer
}
// NewClient returns an implementation of Client for verifying ACME challenges.
func NewClient() Client {
return &client{
http: &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
},
dialer: &net.Dialer{
Timeout: 30 * time.Second,
},
}
}
func (c *client) Get(url string) (*http.Response, error) {
return c.http.Get(url)
}
func (c *client) LookupTxt(name string) ([]string, error) {
return net.LookupTXT(name)
}
func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(c.dialer, network, addr, config)
}

@ -9,27 +9,66 @@ import (
"github.com/smallstep/certificates/authority/provisioner"
)
// Clock that returns time in UTC rounded to seconds.
type Clock struct{}
// Now returns the UTC time rounded to seconds.
func (c *Clock) Now() time.Time {
return time.Now().UTC().Truncate(time.Second)
}
var clock Clock
// CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
AreSANsAllowed(ctx context.Context, sans []string) error
IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error
LoadProvisionerByName(string) (provisioner.Interface, error)
}
// Clock that returns time in UTC rounded to seconds.
type Clock struct{}
// NewContext adds the given acme components to the context.
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
ctx = NewDatabaseContext(ctx, db)
ctx = NewClientContext(ctx, client)
ctx = NewLinkerContext(ctx, linker)
// Prerequisite checker is optional.
if fn != nil {
ctx = NewPrerequisitesCheckerContext(ctx, fn)
}
return ctx
}
// Now returns the UTC time rounded to seconds.
func (c *Clock) Now() time.Time {
return time.Now().UTC().Truncate(time.Second)
// PrerequisitesChecker is a function that checks if all prerequisites for
// serving ACME are met by the CA configuration.
type PrerequisitesChecker func(ctx context.Context) (bool, error)
// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns
// always true.
func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) {
return true, nil
}
var clock Clock
type prerequisitesKey struct{}
// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the
// context.
func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context {
return context.WithValue(ctx, prerequisitesKey{}, fn)
}
// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the
// context.
func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) {
fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker)
return fn, ok && fn != nil
}
// Provisioner is an interface that implements a subset of the provisioner.Interface --
// only those methods required by the ACME api/authority.
type Provisioner interface {
AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
AuthorizeRevoke(ctx context.Context, token string) error
GetID() string
@ -38,16 +77,40 @@ type Provisioner interface {
GetOptions() *provisioner.Options
}
type provisionerKey struct{}
// NewProvisionerContext adds the given provisioner to the context.
func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context {
return context.WithValue(ctx, provisionerKey{}, v)
}
// ProvisionerFromContext returns the current provisioner from the given context.
func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) {
v, ok = ctx.Value(provisionerKey{}).(Provisioner)
return
}
// MustLinkerFromContext returns the current provisioner from the given context.
// It will panic if it's not in the context.
func MustProvisionerFromContext(ctx context.Context) Provisioner {
if v, ok := ProvisionerFromContext(ctx); !ok {
panic("acme provisioner is not the context")
} else {
return v
}
}
// MockProvisioner for testing
type MockProvisioner struct {
Mret1 interface{}
Merr error
MgetID func() string
MgetName func() string
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
MauthorizeRevoke func(ctx context.Context, token string) error
MdefaultTLSCertDuration func() time.Duration
MgetOptions func() *provisioner.Options
Mret1 interface{}
Merr error
MgetID func() string
MgetName func() string
MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
MauthorizeRevoke func(ctx context.Context, token string) error
MdefaultTLSCertDuration func() time.Duration
MgetOptions func() *provisioner.Options
}
// GetName mock
@ -58,6 +121,14 @@ func (m *MockProvisioner) GetName() string {
return m.Mret1.(string)
}
// AuthorizeOrderIdentifiers mock
func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
if m.MauthorizeOrderIdentifier != nil {
return m.MauthorizeOrderIdentifier(ctx, identifier)
}
return m.Merr
}
// AuthorizeSign mock
func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
if m.MauthorizeSign != nil {

@ -23,6 +23,7 @@ type DB interface {
GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error)
GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error)
GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error)
DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error
UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error
@ -48,6 +49,29 @@ type DB interface {
UpdateOrder(ctx context.Context, o *Order) error
}
type dbKey struct{}
// NewDatabaseContext adds the given acme database to the context.
func NewDatabaseContext(ctx context.Context, db DB) context.Context {
return context.WithValue(ctx, dbKey{}, db)
}
// DatabaseFromContext returns the current acme database from the given context.
func DatabaseFromContext(ctx context.Context) (db DB, ok bool) {
db, ok = ctx.Value(dbKey{}).(DB)
return
}
// MustDatabaseFromContext returns the current database from the given context.
// It will panic if it's not in the context.
func MustDatabaseFromContext(ctx context.Context) DB {
if db, ok := DatabaseFromContext(ctx); !ok {
panic("acme database is not in the context")
} else {
return db
}
}
// MockDB is an implementation of the DB interface that should only be used as
// a mock in tests.
type MockDB struct {
@ -60,6 +84,7 @@ type MockDB struct {
MockGetExternalAccountKey func(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error)
MockGetExternalAccountKeys func(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error)
MockGetExternalAccountKeyByReference func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
MockGetExternalAccountKeyByAccountID func(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error)
MockDeleteExternalAccountKey func(ctx context.Context, provisionerID, keyID string) error
MockUpdateExternalAccountKey func(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error
@ -168,6 +193,16 @@ func (m *MockDB) GetExternalAccountKeyByReference(ctx context.Context, provision
return m.MockRet1.(*ExternalAccountKey), m.MockError
}
// GetExternalAccountKeyByAccountID mock
func (m *MockDB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) {
if m.MockGetExternalAccountKeyByAccountID != nil {
return m.MockGetExternalAccountKeyByAccountID(ctx, provisionerID, accountID)
} else if m.MockError != nil {
return nil, m.MockError
}
return m.MockRet1.(*ExternalAccountKey), m.MockError
}
// DeleteExternalAccountKey mock
func (m *MockDB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error {
if m.MockDeleteExternalAccountKey != nil {

@ -8,6 +8,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
nosqlDB "github.com/smallstep/nosql"
)
@ -23,7 +24,7 @@ type dbExternalAccountKey struct {
ProvisionerID string `json:"provisionerID"`
Reference string `json:"reference"`
AccountID string `json:"accountID,omitempty"`
KeyBytes []byte `json:"key"`
HmacKey []byte `json:"key"`
CreatedAt time.Time `json:"createdAt"`
BoundAt time.Time `json:"boundAt"`
}
@ -72,7 +73,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer
ID: keyID,
ProvisionerID: provisionerID,
Reference: reference,
KeyBytes: random,
HmacKey: random,
CreatedAt: clock.Now(),
}
@ -99,7 +100,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer
ProvisionerID: dbeak.ProvisionerID,
Reference: dbeak.Reference,
AccountID: dbeak.AccountID,
KeyBytes: dbeak.KeyBytes,
HmacKey: dbeak.HmacKey,
CreatedAt: dbeak.CreatedAt,
BoundAt: dbeak.BoundAt,
}, nil
@ -124,7 +125,7 @@ func (db *DB) GetExternalAccountKey(ctx context.Context, provisionerID, keyID st
ProvisionerID: dbeak.ProvisionerID,
Reference: dbeak.Reference,
AccountID: dbeak.AccountID,
KeyBytes: dbeak.KeyBytes,
HmacKey: dbeak.HmacKey,
CreatedAt: dbeak.CreatedAt,
BoundAt: dbeak.BoundAt,
}, nil
@ -191,7 +192,7 @@ func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor
}
keys = append(keys, &acme.ExternalAccountKey{
ID: eak.ID,
KeyBytes: eak.KeyBytes,
HmacKey: eak.HmacKey,
ProvisionerID: eak.ProvisionerID,
Reference: eak.Reference,
AccountID: eak.AccountID,
@ -226,6 +227,10 @@ func (db *DB) GetExternalAccountKeyByReference(ctx context.Context, provisionerI
return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID)
}
func (db *DB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
return nil, nil
}
func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
externalAccountKeyMutex.Lock()
defer externalAccountKeyMutex.Unlock()
@ -252,7 +257,7 @@ func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string
ProvisionerID: eak.ProvisionerID,
Reference: eak.Reference,
AccountID: eak.AccountID,
KeyBytes: eak.KeyBytes,
HmacKey: eak.HmacKey,
CreatedAt: eak.CreatedAt,
BoundAt: eak.BoundAt,
}

@ -8,6 +8,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
certdb "github.com/smallstep/certificates/db"
@ -32,7 +33,7 @@ func TestDB_getDBExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: "ref",
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(dbeak)
@ -108,7 +109,7 @@ func TestDB_getDBExternalAccountKey(t *testing.T) {
}
} else if assert.Nil(t, tc.err) {
assert.Equals(t, dbeak.ID, tc.dbeak.ID)
assert.Equals(t, dbeak.KeyBytes, tc.dbeak.KeyBytes)
assert.Equals(t, dbeak.HmacKey, tc.dbeak.HmacKey)
assert.Equals(t, dbeak.ProvisionerID, tc.dbeak.ProvisionerID)
assert.Equals(t, dbeak.Reference, tc.dbeak.Reference)
assert.Equals(t, dbeak.CreatedAt, tc.dbeak.CreatedAt)
@ -136,7 +137,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: "ref",
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(dbeak)
@ -154,7 +155,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: "ref",
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
},
}
@ -179,7 +180,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
ProvisionerID: "aDifferentProvID",
Reference: "ref",
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(dbeak)
@ -197,7 +198,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: "ref",
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
},
acmeErr: acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created"),
@ -225,7 +226,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
}
} else if assert.Nil(t, tc.err) {
assert.Equals(t, eak.ID, tc.eak.ID)
assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes)
assert.Equals(t, eak.HmacKey, tc.eak.HmacKey)
assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID)
assert.Equals(t, eak.Reference, tc.eak.Reference)
assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt)
@ -255,7 +256,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
dbref := &dbExternalAccountKeyReference{
@ -288,7 +289,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
},
err: nil,
@ -392,7 +393,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
assert.Equals(t, eak.AccountID, tc.eak.AccountID)
assert.Equals(t, eak.BoundAt, tc.eak.BoundAt)
assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt)
assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes)
assert.Equals(t, eak.HmacKey, tc.eak.HmacKey)
assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID)
assert.Equals(t, eak.Reference, tc.eak.Reference)
}
@ -420,7 +421,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b1, err := json.Marshal(dbeak1)
@ -430,7 +431,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b2, err := json.Marshal(dbeak2)
@ -440,7 +441,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: "aDifferentProvID",
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b3, err := json.Marshal(dbeak3)
@ -513,7 +514,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
},
{
@ -521,7 +522,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
},
},
@ -598,7 +599,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
assert.Equals(t, "", nextCursor)
for i, eak := range eaks {
assert.Equals(t, eak.ID, tc.eaks[i].ID)
assert.Equals(t, eak.KeyBytes, tc.eaks[i].KeyBytes)
assert.Equals(t, eak.HmacKey, tc.eaks[i].HmacKey)
assert.Equals(t, eak.ProvisionerID, tc.eaks[i].ProvisionerID)
assert.Equals(t, eak.Reference, tc.eaks[i].Reference)
assert.Equals(t, eak.CreatedAt, tc.eaks[i].CreatedAt)
@ -627,7 +628,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
dbref := &dbExternalAccountKeyReference{
@ -707,7 +708,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: "aDifferentProvID",
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(dbeak)
@ -730,7 +731,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
dbref := &dbExternalAccountKeyReference{
@ -780,7 +781,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
dbref := &dbExternalAccountKeyReference{
@ -830,7 +831,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
dbref := &dbExternalAccountKeyReference{
@ -953,7 +954,7 @@ func TestDB_CreateExternalAccountKey(t *testing.T) {
assert.Equals(t, string(key), dbeak.ID)
assert.Equals(t, eak.ProvisionerID, dbeak.ProvisionerID)
assert.Equals(t, eak.Reference, dbeak.Reference)
assert.Equals(t, 32, len(dbeak.KeyBytes))
assert.Equals(t, 32, len(dbeak.HmacKey))
assert.False(t, dbeak.CreatedAt.IsZero())
assert.Equals(t, dbeak.AccountID, eak.AccountID)
assert.True(t, dbeak.BoundAt.IsZero())
@ -1078,7 +1079,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(dbeak)
@ -1096,7 +1097,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
return test{
@ -1120,7 +1121,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
assert.Equals(t, dbNew.AccountID, dbeak.AccountID)
assert.Equals(t, dbNew.CreatedAt, dbeak.CreatedAt)
assert.Equals(t, dbNew.BoundAt, dbeak.BoundAt)
assert.Equals(t, dbNew.KeyBytes, dbeak.KeyBytes)
assert.Equals(t, dbNew.HmacKey, dbeak.HmacKey)
return nu, true, nil
},
},
@ -1148,7 +1149,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: "aDifferentProvID",
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(newDBEAK)
@ -1174,7 +1175,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(newDBEAK)
@ -1200,7 +1201,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(newDBEAK)
@ -1237,7 +1238,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
assert.Equals(t, dbeak.AccountID, tc.eak.AccountID)
assert.Equals(t, dbeak.CreatedAt, tc.eak.CreatedAt)
assert.Equals(t, dbeak.BoundAt, tc.eak.BoundAt)
assert.Equals(t, dbeak.KeyBytes, tc.eak.KeyBytes)
assert.Equals(t, dbeak.HmacKey, tc.eak.HmacKey)
}
})
}

@ -1,100 +1,19 @@
package api
package acme
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"github.com/smallstep/certificates/acme"
"github.com/go-chi/chi"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
)
// NewLinker returns a new Directory type.
func NewLinker(dns, prefix string) Linker {
_, _, err := net.SplitHostPort(dns)
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
// these cases, then the input dns is not changed.
lastIndex := strings.LastIndex(dns, ":")
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
if ip := net.ParseIP(hostPart); ip != nil {
dns = "[" + hostPart + "]:" + portPart
} else if ip := net.ParseIP(dns); ip != nil {
dns = "[" + dns + "]"
}
}
return &linker{prefix: prefix, dns: dns}
}
// Linker interface for generating links for ACME resources.
type Linker interface {
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string
LinkOrder(ctx context.Context, o *acme.Order)
LinkAccount(ctx context.Context, o *acme.Account)
LinkChallenge(ctx context.Context, o *acme.Challenge, azID string)
LinkAuthorization(ctx context.Context, o *acme.Authorization)
LinkOrdersByAccountID(ctx context.Context, orders []string)
}
// linker generates ACME links.
type linker struct {
prefix string
dns string
}
func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
switch typ {
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
return fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType:
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
default:
return ""
}
}
// GetLink is a helper for GetLinkExplicit
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
var (
provName string
baseURL = baseURLFromContext(ctx)
u = url.URL{}
)
if p, err := provisionerFromContext(ctx); err == nil && p != nil {
provName = p.GetName()
}
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
if baseURL != nil {
u = *baseURL
}
u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...)
// If no Scheme is set, then default to https.
if u.Scheme == "" {
u.Scheme = "https"
}
// If no Host is set, then use the default (first DNS attr in the ca.json).
if u.Host == "" {
u.Host = l.dns
}
u.Path = l.prefix + u.Path
return u.String()
}
// LinkType captures the link type.
type LinkType int
@ -160,8 +79,155 @@ func (l LinkType) String() string {
}
}
func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
switch typ {
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
return fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType:
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
default:
return ""
}
}
// NewLinker returns a new Directory type.
func NewLinker(dns, prefix string) Linker {
_, _, err := net.SplitHostPort(dns)
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
// these cases, then the input dns is not changed.
lastIndex := strings.LastIndex(dns, ":")
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
if ip := net.ParseIP(hostPart); ip != nil {
dns = "[" + hostPart + "]:" + portPart
} else if ip := net.ParseIP(dns); ip != nil {
dns = "[" + dns + "]"
}
}
return &linker{prefix: prefix, dns: dns}
}
// Linker interface for generating links for ACME resources.
type Linker interface {
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
Middleware(http.Handler) http.Handler
LinkOrder(ctx context.Context, o *Order)
LinkAccount(ctx context.Context, o *Account)
LinkChallenge(ctx context.Context, o *Challenge, azID string)
LinkAuthorization(ctx context.Context, o *Authorization)
LinkOrdersByAccountID(ctx context.Context, orders []string)
}
type linkerKey struct{}
// NewLinkerContext adds the given linker to the context.
func NewLinkerContext(ctx context.Context, v Linker) context.Context {
return context.WithValue(ctx, linkerKey{}, v)
}
// LinkerFromContext returns the current linker from the given context.
func LinkerFromContext(ctx context.Context) (v Linker, ok bool) {
v, ok = ctx.Value(linkerKey{}).(Linker)
return
}
// MustLinkerFromContext returns the current linker from the given context. It
// will panic if it's not in the context.
func MustLinkerFromContext(ctx context.Context) Linker {
if v, ok := LinkerFromContext(ctx); !ok {
panic("acme linker is not the context")
} else {
return v
}
}
type baseURLKey struct{}
func newBaseURLContext(ctx context.Context, r *http.Request) context.Context {
var u *url.URL
if r.Host != "" {
u = &url.URL{Scheme: "https", Host: r.Host}
}
return context.WithValue(ctx, baseURLKey{}, u)
}
func baseURLFromContext(ctx context.Context) *url.URL {
if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok {
return u
}
return nil
}
// linker generates ACME links.
type linker struct {
prefix string
dns string
}
// Middleware gets the provisioner and current url from the request and sets
// them in the context so we can use the linker to create ACME links.
func (l *linker) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add base url to the context.
ctx := newBaseURLContext(r.Context(), r)
// Add provisioner to the context.
nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped)
if err != nil {
render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return
}
p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name)
if err != nil {
render.Error(w, err)
return
}
acmeProv, ok := p.(*provisioner.ACME)
if !ok {
render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return
}
ctx = NewProvisionerContext(ctx, Provisioner(acmeProv))
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// GetLink is a helper for GetLinkExplicit.
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
var name string
if p, ok := ProvisionerFromContext(ctx); ok {
name = p.GetName()
}
var u url.URL
if baseURL := baseURLFromContext(ctx); baseURL != nil {
u = *baseURL
}
if u.Scheme == "" {
u.Scheme = "https"
}
if u.Host == "" {
u.Host = l.dns
}
u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...)
return u.String()
}
// LinkOrder sets the ACME links required by an ACME order.
func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
func (l *linker) LinkOrder(ctx context.Context, o *Order) {
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
for i, azID := range o.AuthorizationIDs {
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
@ -173,17 +239,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
}
// LinkAccount sets the ACME links required by an ACME account.
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) {
func (l *linker) LinkAccount(ctx context.Context, acc *Account) {
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
}
// LinkChallenge sets the ACME links required by an ACME challenge.
func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) {
func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) {
ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
}
// LinkAuthorization sets the ACME links required by an ACME authorization.
func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) {
func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) {
for _, ch := range az.Challenges {
l.LinkChallenge(ctx, ch, az.ID)
}

@ -1,21 +1,38 @@
package api
package acme
import (
"context"
"fmt"
"net/url"
"testing"
"time"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
)
func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
dns := "ca.smallstep.com"
prefix := "acme"
linker := NewLinker(dns, prefix)
func mockProvisioner(t *testing.T) Provisioner {
t.Helper()
var defaultDisableRenewal = false
// Initialize provisioners
p := &provisioner.ACME{
Type: "ACME",
Name: "test@acme-<test>provisioner.com",
}
if err := p.Init(provisioner.Config{Claims: provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
}}); err != nil {
fmt.Printf("%v", err)
}
return p
}
getPath := linker.GetUnescapedPathSuffix
func TestGetUnescapedPathSuffix(t *testing.T) {
getPath := GetUnescapedPathSuffix
assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
@ -32,9 +49,9 @@ func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
}
func TestLinker_DNS(t *testing.T) {
prov := newProv()
prov := mockProvisioner(t)
escProvName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx := NewProvisionerContext(context.Background(), prov)
type test struct {
name string
dns string
@ -117,19 +134,19 @@ func TestLinker_GetLink(t *testing.T) {
linker := NewLinker(dns, prefix)
id := "1234"
prov := newProv()
prov := mockProvisioner(t)
escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
// No provisioner and no BaseURL from request
assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
// Provisioner: yes, BaseURL: no
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
// Provisioner: no, BaseURL: yes
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
@ -163,37 +180,37 @@ func TestLinker_GetLink(t *testing.T) {
func TestLinker_LinkOrder(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv()
prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
oid := "orderID"
certID := "certID"
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)
type test struct {
o *acme.Order
validate func(o *acme.Order)
o *Order
validate func(o *Order)
}
var tests = map[string]test{
"no-authz-and-no-cert": {
o: &acme.Order{
o: &Order{
ID: oid,
},
validate: func(o *acme.Order) {
validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{})
assert.Equals(t, o.CertificateURL, "")
},
},
"one-authz-and-cert": {
o: &acme.Order{
o: &Order{
ID: oid,
CertificateID: certID,
AuthorizationIDs: []string{"foo"},
},
validate: func(o *acme.Order) {
validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
@ -202,12 +219,12 @@ func TestLinker_LinkOrder(t *testing.T) {
},
},
"many-authz": {
o: &acme.Order{
o: &Order{
ID: oid,
CertificateID: certID,
AuthorizationIDs: []string{"foo", "bar", "zap"},
},
validate: func(o *acme.Order) {
validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
@ -228,24 +245,24 @@ func TestLinker_LinkOrder(t *testing.T) {
func TestLinker_LinkAccount(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv()
prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
accID := "accountID"
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)
type test struct {
a *acme.Account
validate func(o *acme.Account)
a *Account
validate func(o *Account)
}
var tests = map[string]test{
"ok": {
a: &acme.Account{
a: &Account{
ID: accID,
},
validate: func(a *acme.Account) {
validate: func(a *Account) {
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
},
},
@ -260,25 +277,25 @@ func TestLinker_LinkAccount(t *testing.T) {
func TestLinker_LinkChallenge(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv()
prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
chID := "chID"
azID := "azID"
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)
type test struct {
ch *acme.Challenge
validate func(o *acme.Challenge)
ch *Challenge
validate func(o *Challenge)
}
var tests = map[string]test{
"ok": {
ch: &acme.Challenge{
ch: &Challenge{
ID: chID,
},
validate: func(ch *acme.Challenge) {
validate: func(ch *Challenge) {
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
},
},
@ -293,10 +310,10 @@ func TestLinker_LinkChallenge(t *testing.T) {
func TestLinker_LinkAuthorization(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv()
prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
chID0 := "chID-0"
chID1 := "chID-1"
@ -305,20 +322,20 @@ func TestLinker_LinkAuthorization(t *testing.T) {
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)
type test struct {
az *acme.Authorization
validate func(o *acme.Authorization)
az *Authorization
validate func(o *Authorization)
}
var tests = map[string]test{
"ok": {
az: &acme.Authorization{
az: &Authorization{
ID: azID,
Challenges: []*acme.Challenge{
Challenges: []*Challenge{
{ID: chID0},
{ID: chID1},
{ID: chID2},
},
},
validate: func(az *acme.Authorization) {
validate: func(az *Authorization) {
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
@ -335,10 +352,10 @@ func TestLinker_LinkAuthorization(t *testing.T) {
func TestLinker_LinkOrdersByAccountID(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv()
prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)

@ -268,6 +268,7 @@ func TestOrder_UpdateStatus(t *testing.T) {
type mockSignAuth struct {
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
areSANsAllowed func(ctx context.Context, sans []string) error
loadProvisionerByName func(string) (provisioner.Interface, error)
ret1, ret2 interface{}
err error
@ -282,6 +283,13 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error {
if m.areSANsAllowed != nil {
return m.areSANsAllowed(ctx, sans)
}
return m.err
}
func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface, error) {
if m.loadProvisionerByName != nil {
return m.loadProvisionerByName(name)

@ -35,7 +35,6 @@ type Authority interface {
SSHAuthority
// context specifies the Authorize[Sign|Revoke|etc.] method.
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error)
@ -53,6 +52,11 @@ type Authority interface {
GetCertificateRevocationList() ([]byte, error)
}
// mustAuthority will be replaced on unit tests.
var mustAuthority = func(ctx context.Context) Authority {
return authority.MustFromContext(ctx)
}
// TimeDuration is an alias of provisioner.TimeDuration
type TimeDuration = provisioner.TimeDuration
@ -244,49 +248,54 @@ type caHandler struct {
Authority Authority
}
// New creates a new RouterHandler with the CA endpoints.
func New(auth Authority) RouterHandler {
return &caHandler{
Authority: auth,
}
// Route configures the http request router.
func (h *caHandler) Route(r Router) {
Route(r)
}
func (h *caHandler) Route(r Router) {
r.MethodFunc("GET", "/version", h.Version)
r.MethodFunc("GET", "/health", h.Health)
r.MethodFunc("GET", "/root/{sha}", h.Root)
r.MethodFunc("POST", "/sign", h.Sign)
r.MethodFunc("POST", "/renew", h.Renew)
r.MethodFunc("POST", "/rekey", h.Rekey)
r.MethodFunc("POST", "/revoke", h.Revoke)
r.MethodFunc("GET", "/crl", h.CRL)
r.MethodFunc("GET", "/provisioners", h.Provisioners)
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
r.MethodFunc("GET", "/roots", h.Roots)
r.MethodFunc("GET", "/roots.pem", h.RootsPEM)
r.MethodFunc("GET", "/federation", h.Federation)
// New creates a new RouterHandler with the CA endpoints.
//
// Deprecated: Use api.Route(r Router)
func New(auth Authority) RouterHandler {
return &caHandler{}
}
func Route(r Router) {
r.MethodFunc("GET", "/version", Version)
r.MethodFunc("GET", "/health", Health)
r.MethodFunc("GET", "/root/{sha}", Root)
r.MethodFunc("POST", "/sign", Sign)
r.MethodFunc("POST", "/renew", Renew)
r.MethodFunc("POST", "/rekey", Rekey)
r.MethodFunc("POST", "/revoke", Revoke)
r.MethodFunc("GET", "/crl", CRL)
r.MethodFunc("GET", "/provisioners", Provisioners)
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey)
r.MethodFunc("GET", "/roots", Roots)
r.MethodFunc("GET", "/roots.pem", RootsPEM)
r.MethodFunc("GET", "/federation", Federation)
// SSH CA
r.MethodFunc("POST", "/ssh/sign", h.SSHSign)
r.MethodFunc("POST", "/ssh/renew", h.SSHRenew)
r.MethodFunc("POST", "/ssh/revoke", h.SSHRevoke)
r.MethodFunc("POST", "/ssh/rekey", h.SSHRekey)
r.MethodFunc("GET", "/ssh/roots", h.SSHRoots)
r.MethodFunc("GET", "/ssh/federation", h.SSHFederation)
r.MethodFunc("POST", "/ssh/config", h.SSHConfig)
r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig)
r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost)
r.MethodFunc("GET", "/ssh/hosts", h.SSHGetHosts)
r.MethodFunc("POST", "/ssh/bastion", h.SSHBastion)
r.MethodFunc("POST", "/ssh/sign", SSHSign)
r.MethodFunc("POST", "/ssh/renew", SSHRenew)
r.MethodFunc("POST", "/ssh/revoke", SSHRevoke)
r.MethodFunc("POST", "/ssh/rekey", SSHRekey)
r.MethodFunc("GET", "/ssh/roots", SSHRoots)
r.MethodFunc("GET", "/ssh/federation", SSHFederation)
r.MethodFunc("POST", "/ssh/config", SSHConfig)
r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig)
r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost)
r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts)
r.MethodFunc("POST", "/ssh/bastion", SSHBastion)
// For compatibility with old code:
r.MethodFunc("POST", "/re-sign", h.Renew)
r.MethodFunc("POST", "/sign-ssh", h.SSHSign)
r.MethodFunc("GET", "/ssh/get-hosts", h.SSHGetHosts)
r.MethodFunc("POST", "/re-sign", Renew)
r.MethodFunc("POST", "/sign-ssh", SSHSign)
r.MethodFunc("GET", "/ssh/get-hosts", SSHGetHosts)
}
// Version is an HTTP handler that returns the version of the server.
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
v := h.Authority.Version()
func Version(w http.ResponseWriter, r *http.Request) {
v := mustAuthority(r.Context()).Version()
render.JSON(w, VersionResponse{
Version: v.Version,
RequireClientAuthentication: v.RequireClientAuthentication,
@ -294,17 +303,17 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
}
// Health is an HTTP handler that returns the status of the server.
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
func Health(w http.ResponseWriter, r *http.Request) {
render.JSON(w, HealthResponse{Status: "ok"})
}
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
// certificate for the given SHA256.
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
func Root(w http.ResponseWriter, r *http.Request) {
sha := chi.URLParam(r, "sha")
sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
// Load root certificate with the
cert, err := h.Authority.Root(sum)
cert, err := mustAuthority(r.Context()).Root(sum)
if err != nil {
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
return
@ -322,18 +331,19 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
}
// Provisioners returns the list of provisioners configured in the authority.
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
func Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := ParseCursor(r)
if err != nil {
render.Error(w, err)
return
}
p, next, err := h.Authority.GetProvisioners(cursor, limit)
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
if err != nil {
render.Error(w, errs.InternalServerErr(err))
return
}
render.JSON(w, &ProvisionersResponse{
Provisioners: p,
NextCursor: next,
@ -341,19 +351,20 @@ func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
}
// ProvisionerKey returns the encrypted key of a provisioner by it's key id.
func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
func ProvisionerKey(w http.ResponseWriter, r *http.Request) {
kid := chi.URLParam(r, "kid")
key, err := h.Authority.GetEncryptedKey(kid)
key, err := mustAuthority(r.Context()).GetEncryptedKey(kid)
if err != nil {
render.Error(w, errs.NotFoundErr(err))
return
}
render.JSON(w, &ProvisionerKeyResponse{key})
}
// Roots returns all the root certificates for the CA.
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots()
func Roots(w http.ResponseWriter, r *http.Request) {
roots, err := mustAuthority(r.Context()).GetRoots()
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
return
@ -370,8 +381,8 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
}
// RootsPEM returns all the root certificates for the CA in PEM format.
func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots()
func RootsPEM(w http.ResponseWriter, r *http.Request) {
roots, err := mustAuthority(r.Context()).GetRoots()
if err != nil {
render.Error(w, errs.InternalServerErr(err))
return
@ -393,8 +404,8 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
}
// Federation returns all the public certificates in the federation.
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation()
func Federation(w http.ResponseWriter, r *http.Request) {
federated, err := mustAuthority(r.Context()).GetFederation()
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
return

@ -171,10 +171,21 @@ func parseCertificateRequest(data string) *x509.CertificateRequest {
return csr
}
func mockMustAuthority(t *testing.T, a Authority) {
t.Helper()
fn := mustAuthority
t.Cleanup(func() {
mustAuthority = fn
})
mustAuthority = func(ctx context.Context) Authority {
return a
}
}
type mockAuthority struct {
ret1, ret2 interface{}
err error
authorizeSign func(ott string) ([]provisioner.SignOption, error)
authorize func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error)
@ -207,12 +218,8 @@ func (m *mockAuthority) GetCertificateRevocationList() ([]byte, error) {
// TODO: remove once Authorize is deprecated.
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return m.AuthorizeSign(ott)
}
func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
if m.authorizeSign != nil {
return m.authorizeSign(ott)
if m.authorize != nil {
return m.authorize(ctx, ott)
}
return m.ret1.([]provisioner.SignOption), m.err
}
@ -793,11 +800,10 @@ func Test_caHandler_Route(t *testing.T) {
}
}
func Test_caHandler_Health(t *testing.T) {
func Test_Health(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/health", nil)
w := httptest.NewRecorder()
h := New(&mockAuthority{}).(*caHandler)
h.Health(w, req)
Health(w, req)
res := w.Result()
if res.StatusCode != 200 {
@ -815,7 +821,7 @@ func Test_caHandler_Health(t *testing.T) {
}
}
func Test_caHandler_Root(t *testing.T) {
func Test_Root(t *testing.T) {
tests := []struct {
name string
root *x509.Certificate
@ -836,9 +842,9 @@ func Test_caHandler_Root(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler)
mockMustAuthority(t, &mockAuthority{ret1: tt.root, err: tt.err})
w := httptest.NewRecorder()
h.Root(w, req)
Root(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
@ -859,7 +865,7 @@ func Test_caHandler_Root(t *testing.T) {
}
}
func Test_caHandler_Sign(t *testing.T) {
func Test_Sign(t *testing.T) {
csr := parseCertificateRequest(csrPEM)
valid, err := json.Marshal(SignRequest{
CsrPEM: CertificateRequest{csr},
@ -900,18 +906,18 @@ func Test_caHandler_Sign(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
mockMustAuthority(t, &mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return tt.certAttrOpts, tt.autherr
},
getTLSOptions: func() *authority.TLSOptions {
return nil
},
}).(*caHandler)
})
req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
w := httptest.NewRecorder()
h.Sign(logging.NewResponseLogger(w), req)
Sign(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode {
@ -932,7 +938,7 @@ func Test_caHandler_Sign(t *testing.T) {
}
}
func Test_caHandler_Renew(t *testing.T) {
func Test_Renew(t *testing.T) {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
}
@ -1022,7 +1028,7 @@ func Test_caHandler_Renew(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
mockMustAuthority(t, &mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err,
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
@ -1043,12 +1049,12 @@ func Test_caHandler_Renew(t *testing.T) {
getTLSOptions: func() *authority.TLSOptions {
return nil
},
}).(*caHandler)
})
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
req.TLS = tt.tls
req.Header = tt.header
w := httptest.NewRecorder()
h.Renew(logging.NewResponseLogger(w), req)
Renew(logging.NewResponseLogger(w), req)
res := w.Result()
defer res.Body.Close()
@ -1077,7 +1083,7 @@ func Test_caHandler_Renew(t *testing.T) {
}
}
func Test_caHandler_Rekey(t *testing.T) {
func Test_Rekey(t *testing.T) {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
}
@ -1108,16 +1114,16 @@ func Test_caHandler_Rekey(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
mockMustAuthority(t, &mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err,
getTLSOptions: func() *authority.TLSOptions {
return nil
},
}).(*caHandler)
})
req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input))
req.TLS = tt.tls
w := httptest.NewRecorder()
h.Rekey(logging.NewResponseLogger(w), req)
Rekey(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode {
@ -1138,7 +1144,7 @@ func Test_caHandler_Rekey(t *testing.T) {
}
}
func Test_caHandler_Provisioners(t *testing.T) {
func Test_Provisioners(t *testing.T) {
type fields struct {
Authority Authority
}
@ -1204,10 +1210,8 @@ func Test_caHandler_Provisioners(t *testing.T) {
assert.FatalError(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &caHandler{
Authority: tt.fields.Authority,
}
h.Provisioners(tt.args.w, tt.args.r)
mockMustAuthority(t, tt.fields.Authority)
Provisioners(tt.args.w, tt.args.r)
rec := tt.args.w.(*httptest.ResponseRecorder)
res := rec.Result()
@ -1242,7 +1246,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
}
}
func Test_caHandler_ProvisionerKey(t *testing.T) {
func Test_ProvisionerKey(t *testing.T) {
type fields struct {
Authority Authority
}
@ -1274,10 +1278,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &caHandler{
Authority: tt.fields.Authority,
}
h.ProvisionerKey(tt.args.w, tt.args.r)
mockMustAuthority(t, tt.fields.Authority)
ProvisionerKey(tt.args.w, tt.args.r)
rec := tt.args.w.(*httptest.ResponseRecorder)
res := rec.Result()
@ -1302,7 +1304,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
}
}
func Test_caHandler_Roots(t *testing.T) {
func Test_Roots(t *testing.T) {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
}
@ -1323,11 +1325,11 @@ func Test_caHandler_Roots(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
req := httptest.NewRequest("GET", "http://example.com/roots", nil)
req.TLS = tt.tls
w := httptest.NewRecorder()
h.Roots(w, req)
Roots(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
@ -1364,10 +1366,10 @@ func Test_caHandler_RootsPEM(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler)
mockMustAuthority(t, &mockAuthority{ret1: tt.roots, err: tt.err})
req := httptest.NewRequest("GET", "https://example.com/roots", nil)
w := httptest.NewRecorder()
h.RootsPEM(w, req)
RootsPEM(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
@ -1388,7 +1390,7 @@ func Test_caHandler_RootsPEM(t *testing.T) {
}
}
func Test_caHandler_Federation(t *testing.T) {
func Test_Federation(t *testing.T) {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
}
@ -1409,11 +1411,11 @@ func Test_caHandler_Federation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
req := httptest.NewRequest("GET", "http://example.com/federation", nil)
req.TLS = tt.tls
w := httptest.NewRecorder()
h.Federation(w, req)
Federation(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {

@ -3,16 +3,20 @@ package read
import (
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs"
)
// JSON reads JSON from the request body and stores it in the value
// pointed by v.
// pointed to by v.
func JSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequestErr(err, "error decoding json")
@ -21,11 +25,42 @@ func JSON(r io.Reader, v interface{}) error {
}
// ProtoJSON reads JSON from the request body and stores it in the value
// pointed by v.
// pointed to by m.
func ProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r)
if err != nil {
return errs.BadRequestErr(err, "error reading request body")
}
return protojson.Unmarshal(data, m)
switch err := protojson.Unmarshal(data, m); {
case errors.Is(err, proto.Error):
return badProtoJSONError(err.Error())
default:
return err
}
}
// badProtoJSONError is an error type that is returned by ProtoJSON
// when a proto message cannot be unmarshaled. Usually this is caused
// by an error in the request body.
type badProtoJSONError string
// Error implements error for badProtoJSONError
func (e badProtoJSONError) Error() string {
return string(e)
}
// Render implements render.RenderableError for badProtoJSONError
func (e badProtoJSONError) Render(w http.ResponseWriter) {
v := struct {
Type string `json:"type"`
Detail string `json:"detail"`
Message string `json:"message"`
}{
Type: "badRequest",
Detail: "bad request",
// trim the proto prefix for the message
Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")),
}
render.JSONStatus(w, v, http.StatusBadRequest)
}

@ -1,10 +1,21 @@
package read
import (
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"testing/iotest"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/errs"
)
@ -44,3 +55,110 @@ func TestJSON(t *testing.T) {
})
}
}
func TestProtoJSON(t *testing.T) {
p := new(linkedca.Policy) // TODO(hs): can we use something different, so we don't need the import?
type args struct {
r io.Reader
m proto.Message
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "fail/io.ReadAll",
args: args{
r: iotest.ErrReader(errors.New("read error")),
m: p,
},
wantErr: true,
},
{
name: "fail/proto",
args: args{
r: strings.NewReader(`{?}`),
m: p,
},
wantErr: true,
},
{
name: "ok",
args: args{
r: strings.NewReader(`{"x509":{}}`),
m: p,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ProtoJSON(tt.args.r, tt.args.m)
if (err != nil) != tt.wantErr {
t.Errorf("ProtoJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
switch err.(type) {
case badProtoJSONError:
assert.Contains(t, err.Error(), "syntax error")
case *errs.Error:
var ee *errs.Error
if errors.As(err, &ee) {
assert.Equal(t, http.StatusBadRequest, ee.Status)
}
}
return
}
assert.Equal(t, protoreflect.FullName("linkedca.Policy"), proto.MessageName(tt.args.m))
assert.True(t, proto.Equal(&linkedca.Policy{X509: &linkedca.X509Policy{}}, tt.args.m))
})
}
}
func Test_badProtoJSONError_Render(t *testing.T) {
tests := []struct {
name string
e badProtoJSONError
expected string
}{
{
name: "bad proto normal space",
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
expected: "syntax error (line 1:2): invalid value ?",
},
{
name: "bad proto non breaking space",
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
expected: "syntax error (line 1:2): invalid value ?",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
tt.e.Render(w)
res := w.Result()
defer res.Body.Close()
data, err := io.ReadAll(res.Body)
assert.NoError(t, err)
v := struct {
Type string `json:"type"`
Detail string `json:"detail"`
Message string `json:"message"`
}{}
assert.NoError(t, json.Unmarshal(data, &v))
assert.Equal(t, "badRequest", v.Type)
assert.Equal(t, "bad request", v.Detail)
assert.Equal(t, "syntax error (line 1:2): invalid value ?", v.Message)
})
}
}

@ -27,7 +27,7 @@ func (s *RekeyRequest) Validate() error {
}
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
func Rekey(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
render.Error(w, errs.BadRequest("missing client certificate"))
return
@ -44,7 +44,8 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
return
}
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
a := mustAuthority(r.Context())
certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
if err != nil {
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
return
@ -60,6 +61,6 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(),
TLSOptions: a.GetTLSOptions(),
}, http.StatusCreated)
}

@ -16,14 +16,15 @@ const (
// Renew uses the information of certificate in the TLS connection to create a
// new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
cert, err := h.getPeerCertificate(r)
func Renew(w http.ResponseWriter, r *http.Request) {
cert, err := getPeerCertificate(r)
if err != nil {
render.Error(w, err)
return
}
certChain, err := h.Authority.Renew(cert)
a := mustAuthority(r.Context())
certChain, err := a.Renew(cert)
if err != nil {
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return
@ -39,17 +40,18 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(),
TLSOptions: a.GetTLSOptions(),
}, http.StatusCreated)
}
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
func getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
return r.TLS.PeerCertificates[0], nil
}
if s := r.Header.Get(authorizationHeader); s != "" {
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1])
ctx := r.Context()
return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
}
}
return nil, errs.BadRequest("missing client certificate")

@ -1,7 +1,6 @@
package api
import (
"context"
"net/http"
"golang.org/x/crypto/ocsp"
@ -49,7 +48,7 @@ func (r *RevokeRequest) Validate() (err error) {
// NOTE: currently only Passive revocation is supported.
//
// TODO: Add CRL and OCSP support.
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
func Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -68,12 +67,14 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
PassiveOnly: body.Passive,
}
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod)
a := mustAuthority(ctx)
// A token indicates that we are using the api via a provisioner token,
// otherwise it is assumed that the certificate is revoking itself over mTLS.
if len(body.OTT) > 0 {
logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
if _, err := a.Authorize(ctx, body.OTT); err != nil {
render.Error(w, errs.UnauthorizedErr(err))
return
}
@ -98,7 +99,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
opts.MTLS = true
}
if err := h.Authority.Revoke(ctx, opts); err != nil {
if err := a.Revoke(ctx, opts); err != nil {
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
return
}

@ -108,7 +108,7 @@ func Test_caHandler_Revoke(t *testing.T) {
input: string(input),
statusCode: http.StatusOK,
auth: &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return nil, nil
},
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
@ -152,7 +152,7 @@ func Test_caHandler_Revoke(t *testing.T) {
statusCode: http.StatusOK,
tls: cs,
auth: &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return nil, nil
},
revoke: func(ctx context.Context, ri *authority.RevokeOptions) error {
@ -187,7 +187,7 @@ func Test_caHandler_Revoke(t *testing.T) {
input: string(input),
statusCode: http.StatusInternalServerError,
auth: &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return nil, nil
},
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
@ -209,7 +209,7 @@ func Test_caHandler_Revoke(t *testing.T) {
input: string(input),
statusCode: http.StatusForbidden,
auth: &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return nil, nil
},
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
@ -223,13 +223,13 @@ func Test_caHandler_Revoke(t *testing.T) {
for name, _tc := range tests {
tc := _tc(t)
t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*caHandler)
mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
if tc.tls != nil {
req.TLS = tc.tls
}
w := httptest.NewRecorder()
h.Revoke(logging.NewResponseLogger(w), req)
Revoke(logging.NewResponseLogger(w), req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)

@ -49,7 +49,7 @@ type SignResponse struct {
// Sign is an HTTP handler that reads a certificate request and an
// one-time-token (ott) from the body and creates a new certificate with the
// information in the certificate request.
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
func Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -68,13 +68,17 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
TemplateData: body.TemplateData,
}
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
ctx := r.Context()
a := mustAuthority(ctx)
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil {
render.Error(w, errs.UnauthorizedErr(err))
return
}
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
return
@ -89,6 +93,6 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(),
TLSOptions: a.GetTLSOptions(),
}, http.StatusCreated)
}

@ -250,7 +250,7 @@ type SSHBastionResponse struct {
// SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token
// (ott) from the body and creates a new SSH certificate with the information in
// the request.
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
func SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -288,13 +288,16 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
}
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
a := mustAuthority(ctx)
signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil {
render.Error(w, errs.UnauthorizedErr(err))
return
}
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...)
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return
@ -302,7 +305,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var addUserCertificate *SSHCertificate
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert)
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return
@ -315,7 +318,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if cr := body.IdentityCSR.CertificateRequest; cr != nil {
ctx := authority.NewContextWithSkipTokenReuse(r.Context())
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil {
render.Error(w, errs.UnauthorizedErr(err))
return
@ -327,7 +330,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
NotAfter: time.Unix(int64(cert.ValidBefore), 0),
})
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...)
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
return
@ -344,8 +347,9 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
// SSHRoots is an HTTP handler that returns the SSH public keys for user and host
// certificates.
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots(r.Context())
func SSHRoots(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
keys, err := mustAuthority(ctx).GetSSHRoots(ctx)
if err != nil {
render.Error(w, errs.InternalServerErr(err))
return
@ -369,8 +373,9 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
// SSHFederation is an HTTP handler that returns the federated SSH public keys
// for user and host certificates.
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation(r.Context())
func SSHFederation(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
keys, err := mustAuthority(ctx).GetSSHFederation(ctx)
if err != nil {
render.Error(w, errs.InternalServerErr(err))
return
@ -394,7 +399,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients
// and servers.
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
func SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -405,7 +410,8 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
return
}
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
ctx := r.Context()
ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data)
if err != nil {
render.Error(w, errs.InternalServerErr(err))
return
@ -426,7 +432,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
}
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
func SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -437,7 +443,8 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
return
}
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
ctx := r.Context()
exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token)
if err != nil {
render.Error(w, errs.InternalServerErr(err))
return
@ -448,13 +455,14 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
}
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
var cert *x509.Certificate
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
cert = r.TLS.PeerCertificates[0]
}
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
ctx := r.Context()
hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert)
if err != nil {
render.Error(w, errs.InternalServerErr(err))
return
@ -465,7 +473,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
}
// SSHBastion provides returns the bastion configured if any.
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
func SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -476,7 +484,8 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
return
}
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
ctx := r.Context()
bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname)
if err != nil {
render.Error(w, errs.InternalServerErr(err))
return

@ -39,7 +39,7 @@ type SSHRekeyResponse struct {
// SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token
// (ott) from the body and creates a new SSH certificate with the information in
// the request.
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
func SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -59,7 +59,10 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
}
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
a := mustAuthority(ctx)
signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil {
render.Error(w, errs.UnauthorizedErr(err))
return
@ -70,7 +73,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
return
}
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...)
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
return
@ -80,7 +83,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return

@ -37,7 +37,7 @@ type SSHRenewResponse struct {
// SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token
// (ott) from the body and creates a new SSH certificate with the information in
// the request.
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
func SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -51,7 +51,10 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
}
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
_, err := h.Authority.Authorize(ctx, body.OTT)
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
a := mustAuthority(ctx)
_, err := a.Authorize(ctx, body.OTT)
if err != nil {
render.Error(w, errs.UnauthorizedErr(err))
return
@ -62,7 +65,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
return
}
newCert, err := h.Authority.RenewSSH(ctx, oldCert)
newCert, err := a.RenewSSH(ctx, oldCert)
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
return
@ -72,7 +75,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return
@ -85,7 +88,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
}
// renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the
func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
return nil, nil
}
@ -105,7 +108,7 @@ func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfte
cert.NotAfter = notAfter
}
certChain, err := h.Authority.Renew(cert)
certChain, err := mustAuthority(r.Context()).Renew(cert)
if err != nil {
return nil, err
}

@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
// Revoke supports handful of different methods that revoke a Certificate.
//
// NOTE: currently only Passive revocation is supported.
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
func SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -68,16 +68,19 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
}
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
a := mustAuthority(ctx)
// A token indicates that we are using the api via a provisioner token,
// otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
if _, err := a.Authorize(ctx, body.OTT); err != nil {
render.Error(w, errs.UnauthorizedErr(err))
return
}
opts.OTT = body.OTT
if err := h.Authority.Revoke(ctx, opts); err != nil {
if err := a.Revoke(ctx, opts); err != nil {
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
return
}

@ -251,7 +251,7 @@ func TestSignSSHRequest_Validate(t *testing.T) {
}
}
func Test_caHandler_SSHSign(t *testing.T) {
func Test_SSHSign(t *testing.T) {
user, err := getSignedUserCertificate()
assert.FatalError(t, err)
host, err := getSignedHostCertificate()
@ -315,8 +315,8 @@ func Test_caHandler_SSHSign(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
mockMustAuthority(t, &mockAuthority{
authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return []provisioner.SignOption{}, tt.authErr
},
signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
@ -328,11 +328,11 @@ func Test_caHandler_SSHSign(t *testing.T) {
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return tt.tlsSignCerts, tt.tlsSignErr
},
}).(*caHandler)
})
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
w := httptest.NewRecorder()
h.SSHSign(logging.NewResponseLogger(w), req)
SSHSign(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode {
@ -353,7 +353,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
}
}
func Test_caHandler_SSHRoots(t *testing.T) {
func Test_SSHRoots(t *testing.T) {
user, err := ssh.NewPublicKey(sshUserKey.Public())
assert.FatalError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
@ -378,15 +378,15 @@ func Test_caHandler_SSHRoots(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
mockMustAuthority(t, &mockAuthority{
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
return tt.keys, tt.keysErr
},
}).(*caHandler)
})
req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
w := httptest.NewRecorder()
h.SSHRoots(logging.NewResponseLogger(w), req)
SSHRoots(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode {
@ -407,7 +407,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
}
}
func Test_caHandler_SSHFederation(t *testing.T) {
func Test_SSHFederation(t *testing.T) {
user, err := ssh.NewPublicKey(sshUserKey.Public())
assert.FatalError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
@ -432,15 +432,15 @@ func Test_caHandler_SSHFederation(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
mockMustAuthority(t, &mockAuthority{
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
return tt.keys, tt.keysErr
},
}).(*caHandler)
})
req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
w := httptest.NewRecorder()
h.SSHFederation(logging.NewResponseLogger(w), req)
SSHFederation(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode {
@ -461,7 +461,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
}
}
func Test_caHandler_SSHConfig(t *testing.T) {
func Test_SSHConfig(t *testing.T) {
userOutput := []templates.Output{
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
@ -492,15 +492,15 @@ func Test_caHandler_SSHConfig(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
mockMustAuthority(t, &mockAuthority{
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
return tt.output, tt.err
},
}).(*caHandler)
})
req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
w := httptest.NewRecorder()
h.SSHConfig(logging.NewResponseLogger(w), req)
SSHConfig(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode {
@ -521,7 +521,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
}
}
func Test_caHandler_SSHCheckHost(t *testing.T) {
func Test_SSHCheckHost(t *testing.T) {
tests := []struct {
name string
req string
@ -539,15 +539,15 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
mockMustAuthority(t, &mockAuthority{
checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
return tt.exists, tt.err
},
}).(*caHandler)
})
req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
w := httptest.NewRecorder()
h.SSHCheckHost(logging.NewResponseLogger(w), req)
SSHCheckHost(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode {
@ -568,7 +568,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
}
}
func Test_caHandler_SSHGetHosts(t *testing.T) {
func Test_SSHGetHosts(t *testing.T) {
hosts := []authority.Host{
{HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
@ -590,15 +590,15 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
mockMustAuthority(t, &mockAuthority{
getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
return tt.hosts, tt.err
},
}).(*caHandler)
})
req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
w := httptest.NewRecorder()
h.SSHGetHosts(logging.NewResponseLogger(w), req)
SSHGetHosts(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode {
@ -619,7 +619,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
}
}
func Test_caHandler_SSHBastion(t *testing.T) {
func Test_SSHBastion(t *testing.T) {
bastion := &authority.Bastion{
Hostname: "bastion.local",
}
@ -645,15 +645,15 @@ func Test_caHandler_SSHBastion(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
mockMustAuthority(t, &mockAuthority{
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
return tt.bastion, tt.bastionErr
},
}).(*caHandler)
})
req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
w := httptest.NewRecorder()
h.SSHBastion(logging.NewResponseLogger(w), req)
SSHBastion(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode {

@ -1,22 +1,15 @@
package api
import (
"context"
"fmt"
"net/http"
"github.com/go-chi/chi"
"go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
)
const (
// provisionerContextKey provisioner key
provisionerContextKey = ContextKey("provisioner")
)
// CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests
@ -40,78 +33,121 @@ type GetExternalAccountKeysResponse struct {
// requireEABEnabled is a middleware that ensures ACME EAB is enabled
// before serving requests that act on ACME EAB credentials.
func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP {
func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
provName := chi.URLParam(r, "provisionerName")
eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName)
if err != nil {
render.Error(w, err)
prov := linkedca.MustProvisionerFromContext(ctx)
acmeProvisioner := prov.GetDetails().GetACME()
if acmeProvisioner == nil {
render.Error(w, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName()))
return
}
if !eabEnabled {
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
if !acmeProvisioner.RequireEab {
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName()))
return
}
ctx = context.WithValue(ctx, provisionerContextKey, prov)
next(w, r.WithContext(ctx))
}
}
// provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME
// provisioner is set to true and thus has EAB enabled.
func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) {
var (
p provisioner.Interface
err error
)
if p, err = h.auth.LoadProvisionerByName(provisionerName); err != nil {
return false, nil, admin.WrapErrorISE(err, "error loading provisioner %s", provisionerName)
}
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
if err != nil {
return false, nil, admin.WrapErrorISE(err, "error getting provisioner with ID: %s", p.GetID())
next(w, r)
}
details := prov.GetDetails()
if details == nil {
return false, nil, admin.NewErrorISE("error getting details for provisioner with ID: %s", p.GetID())
}
acmeProvisioner := details.GetACME()
if acmeProvisioner == nil {
return false, nil, admin.NewErrorISE("error getting ACME details for provisioner with ID: %s", p.GetID())
}
return acmeProvisioner.GetRequireEab(), prov, nil
}
type acmeAdminResponderInterface interface {
// ACMEAdminResponder is responsible for writing ACME admin responses
type ACMEAdminResponder interface {
GetExternalAccountKeys(w http.ResponseWriter, r *http.Request)
CreateExternalAccountKey(w http.ResponseWriter, r *http.Request)
DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request)
}
// ACMEAdminResponder is responsible for writing ACME admin responses
type ACMEAdminResponder struct{}
// acmeAdminResponder implements ACMEAdminResponder.
type acmeAdminResponder struct{}
// NewACMEAdminResponder returns a new ACMEAdminResponder
func NewACMEAdminResponder() *ACMEAdminResponder {
return &ACMEAdminResponder{}
func NewACMEAdminResponder() ACMEAdminResponder {
return &acmeAdminResponder{}
}
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint
func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
}
// CreateExternalAccountKey writes the response for the EAB key POST endpoint
func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
}
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
}
func eakToLinked(k *acme.ExternalAccountKey) *linkedca.EABKey {
if k == nil {
return nil
}
eak := &linkedca.EABKey{
Id: k.ID,
HmacKey: k.HmacKey,
Provisioner: k.ProvisionerID,
Reference: k.Reference,
Account: k.AccountID,
CreatedAt: timestamppb.New(k.CreatedAt),
BoundAt: timestamppb.New(k.BoundAt),
}
if k.Policy != nil {
eak.Policy = &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{},
Deny: &linkedca.X509Names{},
},
}
eak.Policy.X509.Allow.Dns = k.Policy.X509.Allowed.DNSNames
eak.Policy.X509.Allow.Ips = k.Policy.X509.Allowed.IPRanges
eak.Policy.X509.Deny.Dns = k.Policy.X509.Denied.DNSNames
eak.Policy.X509.Deny.Ips = k.Policy.X509.Denied.IPRanges
eak.Policy.X509.AllowWildcardNames = k.Policy.X509.AllowWildcardNames
}
return eak
}
func linkedEAKToCertificates(k *linkedca.EABKey) *acme.ExternalAccountKey {
if k == nil {
return nil
}
eak := &acme.ExternalAccountKey{
ID: k.Id,
ProvisionerID: k.Provisioner,
Reference: k.Reference,
AccountID: k.Account,
HmacKey: k.HmacKey,
CreatedAt: k.CreatedAt.AsTime(),
BoundAt: k.BoundAt.AsTime(),
}
if policy := k.GetPolicy(); policy != nil {
eak.Policy = &acme.Policy{}
if x509 := policy.GetX509(); x509 != nil {
eak.Policy.X509 = acme.X509Policy{}
if allow := x509.GetAllow(); allow != nil {
eak.Policy.X509.Allowed = acme.PolicyNames{}
eak.Policy.X509.Allowed.DNSNames = allow.Dns
eak.Policy.X509.Allowed.IPRanges = allow.Ips
}
if deny := x509.GetDeny(); deny != nil {
eak.Policy.X509.Denied = acme.PolicyNames{}
eak.Policy.X509.Denied.DNSNames = deny.Dns
eak.Policy.X509.Denied.IPRanges = deny.Ips
}
eak.Policy.X509.AllowWildcardNames = x509.AllowWildcardNames
}
}
return eak
}

@ -4,20 +4,24 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/admin"
)
func readProtoJSON(r io.ReadCloser, m proto.Message) error {
@ -29,109 +33,90 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error {
return protojson.Unmarshal(data, m)
}
func mockMustAuthority(t *testing.T, a adminAuthority) {
t.Helper()
fn := mustAuthority
t.Cleanup(func() {
mustAuthority = fn
})
mustAuthority = func(ctx context.Context) adminAuthority {
return a
}
}
func TestHandler_requireEABEnabled(t *testing.T) {
type test struct {
ctx context.Context
adminDB admin.DB
auth adminAuthority
next nextHTTP
next http.HandlerFunc
err *admin.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"fail/h.provisionerHasEABEnabled": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return nil, errors.New("force")
},
"fail/prov.GetDetails": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
Name: "provName",
}
err := admin.NewErrorISE("error loading provisioner provName: force")
err.Message = "error loading provisioner provName: force"
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'")
err.Message = "error getting ACME details for provisioner 'provName'"
return test{
ctx: ctx,
auth: auth,
err: err,
statusCode: 500,
}
},
"ok/eab-disabled": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
"fail/prov.GetDetails.GetACME": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{},
}
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'")
err.Message = "error getting ACME details for provisioner 'provName'"
return test{
ctx: ctx,
err: err,
statusCode: 500,
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: false,
},
},
},
"ok/eab-disabled": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: false,
},
}, nil
},
},
}
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName")
err.Message = "ACME EAB not enabled for provisioner provName"
err.Message = "ACME EAB not enabled for provisioner 'provName'"
return test{
ctx: ctx,
auth: auth,
adminDB: db,
err: err,
statusCode: 400,
}
},
"ok/eab-enabled": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
prov := &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: true,
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: true,
},
},
},
}, nil
},
},
}
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
return test{
ctx: ctx,
auth: auth,
adminDB: db,
ctx: ctx,
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(nil) // mock response with status 200
},
@ -143,16 +128,9 @@ func TestHandler_requireEABEnabled(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
adminDB: tc.adminDB,
acmeDB: nil,
}
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx)
req := httptest.NewRequest("GET", "/foo", nil).WithContext(tc.ctx)
w := httptest.NewRecorder()
h.requireEABEnabled(tc.next)(w, req)
requireEABEnabled(tc.next)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -176,216 +154,6 @@ func TestHandler_requireEABEnabled(t *testing.T) {
}
}
func TestHandler_provisionerHasEABEnabled(t *testing.T) {
type test struct {
adminDB admin.DB
auth adminAuthority
provisionerName string
want bool
err *admin.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return nil, errors.New("force")
},
}
return test{
auth: auth,
provisionerName: "provName",
want: false,
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
}
},
"fail/db.GetProvisioner": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return nil, errors.New("force")
},
}
return test{
auth: auth,
adminDB: db,
provisionerName: "provName",
want: false,
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
}
},
"fail/prov.GetDetails": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: nil,
}, nil
},
}
return test{
auth: auth,
adminDB: db,
provisionerName: "provName",
want: false,
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
}
},
"fail/details.GetACME": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: nil,
},
},
}, nil
},
}
return test{
auth: auth,
adminDB: db,
provisionerName: "provName",
want: false,
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
}
},
"ok/eab-disabled": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "eab-disabled", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "eab-disabled",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: false,
},
},
},
}, nil
},
}
return test{
adminDB: db,
auth: auth,
provisionerName: "eab-disabled",
want: false,
}
},
"ok/eab-enabled": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "eab-enabled", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "eab-enabled",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: true,
},
},
},
}, nil
},
}
return test{
adminDB: db,
auth: auth,
provisionerName: "eab-enabled",
want: true,
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
adminDB: tc.adminDB,
acmeDB: nil,
}
got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName)
if (err != nil) != (tc.err != nil) {
t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err)
return
}
if tc.err != nil {
assert.Type(t, &linkedca.Provisioner{}, prov)
assert.Type(t, &admin.Error{}, err)
adminError, _ := err.(*admin.Error)
assert.Equals(t, tc.err.Type, adminError.Type)
assert.Equals(t, tc.err.Status, adminError.Status)
assert.Equals(t, tc.err.StatusCode(), adminError.StatusCode())
assert.Equals(t, tc.err.Message, adminError.Message)
assert.Equals(t, tc.err.Detail, adminError.Detail)
return
}
if got != tc.want {
t.Errorf("Handler.provisionerHasEABEnabled() = %v, want %v", got, tc.want)
}
})
}
}
func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) {
type fields struct {
Reference string
@ -585,3 +353,206 @@ func TestHandler_GetExternalAccountKeys(t *testing.T) {
})
}
}
func Test_eakToLinked(t *testing.T) {
tests := []struct {
name string
k *acme.ExternalAccountKey
want *linkedca.EABKey
}{
{
name: "no-key",
k: nil,
want: nil,
},
{
name: "no-policy",
k: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: nil,
},
want: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: nil,
},
},
{
name: "with-policy",
k: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.local"},
IPRanges: []string{"10.0.0.0/24"},
},
Denied: acme.PolicyNames{
DNSNames: []string{"badhost.local"},
IPRanges: []string{"10.0.0.30"},
},
},
},
},
want: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
Ips: []string{"10.0.0.0/24"},
},
Deny: &linkedca.X509Names{
Dns: []string{"badhost.local"},
Ips: []string{"10.0.0.30"},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := eakToLinked(tt.k); !reflect.DeepEqual(got, tt.want) {
t.Errorf("eakToLinked() = %v, want %v", got, tt.want)
}
})
}
}
func Test_linkedEAKToCertificates(t *testing.T) {
tests := []struct {
name string
k *linkedca.EABKey
want *acme.ExternalAccountKey
}{
{
name: "no-key",
k: nil,
want: nil,
},
{
name: "no-policy",
k: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: nil,
},
want: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: nil,
},
},
{
name: "no-x509-policy",
k: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: &linkedca.Policy{},
},
want: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: &acme.Policy{},
},
},
{
name: "with-x509-policy",
k: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
Ips: []string{"10.0.0.0/24"},
},
Deny: &linkedca.X509Names{
Dns: []string{"badhost.local"},
Ips: []string{"10.0.0.30"},
},
AllowWildcardNames: true,
},
},
},
want: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.local"},
IPRanges: []string{"10.0.0.0/24"},
},
Denied: acme.PolicyNames{
DNSNames: []string{"badhost.local"},
IPRanges: []string{"10.0.0.30"},
},
AllowWildcardNames: true,
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := linkedEAKToCertificates(tt.k); !reflect.DeepEqual(got, tt.want) {
t.Errorf("linkedEAKToCertificates() = %v, want %v", got, tt.want)
}
})
}
}

@ -29,6 +29,10 @@ type adminAuthority interface {
LoadProvisionerByID(id string) (provisioner.Interface, error)
UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error
RemoveProvisioner(ctx context.Context, id string) error
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
RemoveAuthorityPolicy(ctx context.Context) error
}
// CreateAdminRequest represents the body for a CreateAdmin request.
@ -81,10 +85,10 @@ type DeleteResponse struct {
}
// GetAdmin returns the requested admin, or an error.
func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
func GetAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
adm, ok := h.auth.LoadAdminByID(id)
adm, ok := mustAuthority(r.Context()).LoadAdminByID(id)
if !ok {
render.Error(w, admin.NewError(admin.ErrorNotFoundType,
"admin %s not found", id))
@ -94,7 +98,7 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
}
// GetAdmins returns a segment of admins associated with the authority.
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
func GetAdmins(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r)
if err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
@ -102,7 +106,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
return
}
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit)
admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit)
if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
return
@ -114,7 +118,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
}
// CreateAdmin creates a new admin.
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
func CreateAdmin(w http.ResponseWriter, r *http.Request) {
var body CreateAdminRequest
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
@ -126,7 +130,8 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
return
}
p, err := h.auth.LoadProvisionerByName(body.Provisioner)
auth := mustAuthority(r.Context())
p, err := auth.LoadProvisionerByName(body.Provisioner)
if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
return
@ -137,7 +142,7 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
Type: body.Type,
}
// Store to authority collection.
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil {
if err := auth.StoreAdmin(r.Context(), adm, p); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
return
}
@ -146,10 +151,10 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
}
// DeleteAdmin deletes admin.
func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
func DeleteAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil {
if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
return
}
@ -158,7 +163,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
}
// UpdateAdmin updates an existing admin.
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
func UpdateAdmin(w http.ResponseWriter, r *http.Request) {
var body UpdateAdminRequest
if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
@ -171,8 +176,8 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
}
id := chi.URLParam(r, "id")
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
auth := mustAuthority(r.Context())
adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
return

@ -14,11 +14,13 @@ import (
"github.com/go-chi/chi"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/timestamppb"
)
type mockAdminAuthority struct {
@ -37,6 +39,11 @@ type mockAdminAuthority struct {
MockLoadProvisionerByID func(id string) (provisioner.Interface, error)
MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error
MockRemoveProvisioner func(ctx context.Context, id string) error
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
MockCreateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
MockUpdateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
MockRemoveAuthorityPolicy func(ctx context.Context) error
}
func (m *mockAdminAuthority) IsAdminAPIEnabled() bool {
@ -130,6 +137,34 @@ func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) e
return m.MockErr
}
func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
if m.MockGetAuthorityPolicy != nil {
return m.MockGetAuthorityPolicy(ctx)
}
return m.MockRet1.(*linkedca.Policy), m.MockErr
}
func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
if m.MockCreateAuthorityPolicy != nil {
return m.MockCreateAuthorityPolicy(ctx, adm, policy)
}
return m.MockRet1.(*linkedca.Policy), m.MockErr
}
func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
if m.MockUpdateAuthorityPolicy != nil {
return m.MockUpdateAuthorityPolicy(ctx, adm, policy)
}
return m.MockRet1.(*linkedca.Policy), m.MockErr
}
func (m *mockAdminAuthority) RemoveAuthorityPolicy(ctx context.Context) error {
if m.MockRemoveAuthorityPolicy != nil {
return m.MockRemoveAuthorityPolicy(ctx)
}
return m.MockErr
}
func TestCreateAdminRequest_Validate(t *testing.T) {
type fields struct {
Subject string
@ -317,14 +352,11 @@ func TestHandler_GetAdmin(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetAdmin(w, req)
GetAdmin(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -456,13 +488,10 @@ func TestHandler_GetAdmins(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetAdmins(w, req)
GetAdmins(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -640,13 +669,11 @@ func TestHandler_CreateAdmin(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.CreateAdmin(w, req)
CreateAdmin(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -732,13 +759,11 @@ func TestHandler_DeleteAdmin(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.DeleteAdmin(w, req)
DeleteAdmin(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -877,13 +902,11 @@ func TestHandler_UpdateAdmin(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.UpdateAdmin(w, req)
UpdateAdmin(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)

@ -1,56 +1,117 @@
package api
import (
"context"
"net/http"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin"
)
// Handler is the Admin API request handler.
type Handler struct {
adminDB admin.DB
auth adminAuthority
acmeDB acme.DB
acmeResponder acmeAdminResponderInterface
acmeResponder ACMEAdminResponder
policyResponder PolicyAdminResponder
}
// Route traffic and implement the Router interface.
//
// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder)
func (h *Handler) Route(r api.Router) {
Route(r, h.acmeResponder, h.policyResponder)
}
// NewHandler returns a new Authority Config Handler.
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface) api.RouterHandler {
//
// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder)
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) api.RouterHandler {
return &Handler{
auth: auth,
adminDB: adminDB,
acmeDB: acmeDB,
acmeResponder: acmeResponder,
acmeResponder: acmeResponder,
policyResponder: policyResponder,
}
}
var mustAuthority = func(ctx context.Context) adminAuthority {
return authority.MustFromContext(ctx)
}
// Route traffic and implement the Router interface.
func (h *Handler) Route(r api.Router) {
authnz := func(next nextHTTP) nextHTTP {
return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next))
func Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) {
authnz := func(next http.HandlerFunc) http.HandlerFunc {
return extractAuthorizeTokenAdmin(requireAPIEnabled(next))
}
enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
return checkAction(next, true)
}
disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
return checkAction(next, false)
}
requireEABEnabled := func(next nextHTTP) nextHTTP {
return h.requireEABEnabled(next)
acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(loadProvisionerByName(requireEABEnabled(next)))
}
authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(enabledInStandalone(next))
}
provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(disabledInStandalone(loadProvisionerByName(next)))
}
acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(disabledInStandalone(loadProvisionerByName(requireEABEnabled(loadExternalAccountKey(next)))))
}
// Provisioners
r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner))
r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners))
r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner))
r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner))
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner))
r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner))
r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners))
r.MethodFunc("POST", "/provisioners", authnz(CreateProvisioner))
r.MethodFunc("PUT", "/provisioners/{name}", authnz(UpdateProvisioner))
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(DeleteProvisioner))
// Admins
r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin))
r.MethodFunc("GET", "/admins", authnz(h.GetAdmins))
r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin))
r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin))
r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin))
// ACME External Account Binding Keys
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys)))
r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys)))
r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey)))
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey)))
r.MethodFunc("GET", "/admins/{id}", authnz(GetAdmin))
r.MethodFunc("GET", "/admins", authnz(GetAdmins))
r.MethodFunc("POST", "/admins", authnz(CreateAdmin))
r.MethodFunc("PATCH", "/admins/{id}", authnz(UpdateAdmin))
r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin))
// ACME responder
if acmeResponder != nil {
// ACME External Account Binding Keys
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.CreateExternalAccountKey))
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(acmeResponder.DeleteExternalAccountKey))
}
// Policy responder
if policyResponder != nil {
// Policy - Authority
r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(policyResponder.GetAuthorityPolicy))
r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(policyResponder.CreateAuthorityPolicy))
r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(policyResponder.UpdateAuthorityPolicy))
r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(policyResponder.DeleteAuthorityPolicy))
// Policy - Provisioner
r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.GetProvisionerPolicy))
r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.CreateProvisionerPolicy))
r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.UpdateProvisionerPolicy))
r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.DeleteProvisionerPolicy))
// Policy - ACME Account
r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
}
}

@ -1,22 +1,26 @@
package api
import (
"context"
"errors"
"net/http"
"github.com/go-chi/chi"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/admin/db/nosql"
"github.com/smallstep/certificates/authority/provisioner"
)
type nextHTTP = func(http.ResponseWriter, *http.Request)
// requireAPIEnabled is a middleware that ensures the Administration API
// is enabled before servicing requests.
func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !h.auth.IsAdminAPIEnabled() {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
"administration API not enabled"))
if !mustAuthority(r.Context()).IsAdminAPIEnabled() {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled"))
return
}
next(w, r)
@ -24,8 +28,9 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
}
// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token.
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
tok := r.Header.Get("Authorization")
if tok == "" {
render.Error(w, admin.NewError(admin.ErrorUnauthorizedType,
@ -33,22 +38,111 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
return
}
adm, err := h.auth.AuthorizeAdminToken(r, tok)
ctx := r.Context()
adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok)
if err != nil {
render.Error(w, err)
return
}
ctx := context.WithValue(r.Context(), adminContextKey, adm)
ctx = linkedca.NewContextWithAdmin(ctx, adm)
next(w, r.WithContext(ctx))
}
}
// ContextKey is the key type for storing and searching for ACME request
// essentials in the context of a request.
type ContextKey string
// loadProvisionerByName is a middleware that searches for a provisioner
// by name and stores it in the context.
func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var (
p provisioner.Interface
err error
)
const (
// adminContextKey account key
adminContextKey = ContextKey("admin")
)
ctx := r.Context()
auth := mustAuthority(ctx)
adminDB := admin.MustFromContext(ctx)
name := chi.URLParam(r, "provisionerName")
// TODO(hs): distinguish 404 vs. 500
if p, err = auth.LoadProvisionerByName(name); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return
}
prov, err := adminDB.GetProvisioner(ctx, p.GetID())
if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name))
return
}
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
next(w, r.WithContext(ctx))
}
}
// checkAction checks if an action is supported in standalone or not
func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// actions allowed in standalone mode are always supported
if supportedInStandalone {
next(w, r)
return
}
// when an action is not supported in standalone mode and when
// using a nosql.DB backend, actions are not supported
if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
"operation not supported in standalone mode"))
return
}
// continue to next http handler
next(w, r)
}
}
// loadExternalAccountKey is a middleware that searches for an ACME
// External Account Key by reference or keyID and stores it in the context.
func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx)
acmeDB := acme.MustDatabaseFromContext(ctx)
reference := chi.URLParam(r, "reference")
keyID := chi.URLParam(r, "keyID")
var (
eak *acme.ExternalAccountKey
err error
)
if keyID != "" {
eak, err = acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID)
} else {
eak, err = acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference)
}
if err != nil {
if errors.Is(err, acme.ErrNotFound) {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error retrieving ACME External Account Key"))
return
}
if eak == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
return
}
linkedEAK := eakToLinked(eak)
ctx = linkedca.NewContextWithExternalAccountKey(ctx, linkedEAK)
next(w, r.WithContext(ctx))
}
}

@ -4,25 +4,32 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/admin"
"go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/smallstep/certificates/authority/admin/db/nosql"
"github.com/smallstep/certificates/authority/provisioner"
)
func TestHandler_requireAPIEnabled(t *testing.T) {
type test struct {
ctx context.Context
auth adminAuthority
next nextHTTP
next http.HandlerFunc
err *admin.Error
statusCode int
}
@ -64,13 +71,11 @@ func TestHandler_requireAPIEnabled(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.requireAPIEnabled(tc.next)(w, req)
requireAPIEnabled(tc.next)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -102,7 +107,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
ctx context.Context
auth adminAuthority
req *http.Request
next nextHTTP
next http.HandlerFunc
err *admin.Error
statusCode int
}
@ -152,7 +157,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
req.Header["Authorization"] = []string{"token"}
createdAt := time.Now()
var deletedAt time.Time
admin := &linkedca.Admin{
adm := &linkedca.Admin{
Id: "adminID",
AuthorityId: "authorityID",
Subject: "admin",
@ -164,20 +169,15 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
auth := &mockAdminAuthority{
MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) {
assert.Equals(t, "token", token)
return admin, nil
return adm, nil
},
}
next := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
a := ctx.Value(adminContextKey) // verifying that the context now has a linkedca.Admin
adm, ok := a.(*linkedca.Admin)
if !ok {
t.Errorf("expected *linkedca.Admin; got %T", a)
return
}
adm := linkedca.MustAdminFromContext(ctx) // verifying that the context now has a linkedca.Admin
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})}
if !cmp.Equal(admin, adm, opts...) {
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(admin, adm, opts...))
if !cmp.Equal(adm, adm, opts...) {
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(adm, adm, opts...))
}
w.Write(nil) // mock response with status 200
}
@ -194,13 +194,459 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
mockMustAuthority(t, tc.auth)
req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder()
extractAuthorizeTokenAdmin(tc.next)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
body, err := io.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 {
err := admin.Error{}
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
assert.Equals(t, tc.err.Type, err.Type)
assert.Equals(t, tc.err.Message, err.Message)
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
assert.Equals(t, tc.err.Detail, err.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
return
}
})
}
}
req := tc.req.WithContext(tc.ctx)
func TestHandler_loadProvisionerByName(t *testing.T) {
type test struct {
adminDB admin.DB
auth adminAuthority
ctx context.Context
next http.HandlerFunc
err *admin.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return nil, errors.New("force")
},
}
err := admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName")
err.Message = "error loading provisioner provName: force"
return test{
ctx: ctx,
auth: auth,
adminDB: &admin.MockDB{},
statusCode: 500,
err: err,
}
},
"fail/db.GetProvisioner": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return nil, errors.New("force")
},
}
err := admin.WrapErrorISE(errors.New("force"), "error retrieving provisioner provName")
err.Message = "error retrieving provisioner provName: force"
return test{
ctx: ctx,
auth: auth,
adminDB: db,
statusCode: 500,
err: err,
}
},
"ok": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
}, nil
},
}
return test{
ctx: ctx,
auth: auth,
adminDB: db,
statusCode: 200,
next: func(w http.ResponseWriter, r *http.Request) {
prov := linkedca.MustProvisionerFromContext(r.Context())
assert.NotNil(t, prov)
assert.Equals(t, "provID", prov.GetId())
assert.Equals(t, "provName", prov.GetName())
w.Write(nil) // mock response with status 200
},
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.auth)
ctx := admin.NewContext(tc.ctx, tc.adminDB)
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(ctx)
w := httptest.NewRecorder()
loadProvisionerByName(tc.next)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
body, err := io.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 {
err := admin.Error{}
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
assert.Equals(t, tc.err.Type, err.Type)
assert.Equals(t, tc.err.Message, err.Message)
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
assert.Equals(t, tc.err.Detail, err.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
return
}
})
}
}
func TestHandler_checkAction(t *testing.T) {
type test struct {
adminDB admin.DB
next http.HandlerFunc
supportedInStandalone bool
err *admin.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"standalone-nosql-supported": func(t *testing.T) test {
return test{
supportedInStandalone: true,
adminDB: &nosql.DB{},
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(nil) // mock response with status 200
},
statusCode: 200,
}
},
"standalone-nosql-not-supported": func(t *testing.T) test {
err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported in standalone mode")
err.Message = "operation not supported in standalone mode"
return test{
supportedInStandalone: false,
adminDB: &nosql.DB{},
statusCode: 501,
err: err,
}
},
"standalone-no-nosql-not-supported": func(t *testing.T) test {
err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported")
err.Message = "operation not supported"
return test{
supportedInStandalone: false,
adminDB: &admin.MockDB{},
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(nil) // mock response with status 200
},
statusCode: 200,
err: err,
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
ctx := admin.NewContext(context.Background(), tc.adminDB)
req := httptest.NewRequest("GET", "/foo", nil).WithContext(ctx)
w := httptest.NewRecorder()
checkAction(tc.next, tc.supportedInStandalone)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
body, err := io.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 {
err := admin.Error{}
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
assert.Equals(t, tc.err.Type, err.Type)
assert.Equals(t, tc.err.Message, err.Message)
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
assert.Equals(t, tc.err.Detail, err.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
return
}
})
}
}
func TestHandler_loadExternalAccountKey(t *testing.T) {
type test struct {
ctx context.Context
acmeDB acme.DB
next http.HandlerFunc
err *admin.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"fail/keyID-not-found-error": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("keyID", "key")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "key", keyID)
return nil, acme.ErrNotFound
},
},
err: err,
statusCode: 404,
}
},
"fail/keyID-error": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("keyID", "key")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key")
err.Message = "error retrieving ACME External Account Key: force"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "key", keyID)
return nil, errors.New("force")
},
},
err: err,
statusCode: 500,
}
},
"fail/reference-not-found-error": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("reference", "ref")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "ref", reference)
return nil, acme.ErrNotFound
},
},
err: err,
statusCode: 404,
}
},
"fail/reference-error": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("reference", "ref")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key")
err.Message = "error retrieving ACME External Account Key: force"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "ref", reference)
return nil, errors.New("force")
},
},
err: err,
statusCode: 500,
}
},
"fail/no-key": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("reference", "ref")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "ref", reference)
return nil, nil
},
},
err: err,
statusCode: 404,
}
},
"ok/keyID": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("keyID", "eakID")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
createdAt := time.Now().Add(-1 * time.Hour)
var boundAt time.Time
eak := &acme.ExternalAccountKey{
ID: "eakID",
ProvisionerID: "provID",
CreatedAt: createdAt,
BoundAt: boundAt,
}
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "eakID", keyID)
return eak, nil
},
},
next: func(w http.ResponseWriter, r *http.Request) {
contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context())
assert.NotNil(t, eak)
exp := &linkedca.EABKey{
Id: "eakID",
Provisioner: "provID",
CreatedAt: timestamppb.New(createdAt),
BoundAt: timestamppb.New(boundAt),
}
assert.Equals(t, exp, contextEAK)
w.Write(nil) // mock response with status 200
},
err: nil,
statusCode: 200,
}
},
"ok/reference": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("reference", "ref")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
createdAt := time.Now().Add(-1 * time.Hour)
var boundAt time.Time
eak := &acme.ExternalAccountKey{
ID: "eakID",
ProvisionerID: "provID",
Reference: "ref",
CreatedAt: createdAt,
BoundAt: boundAt,
}
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "ref", reference)
return eak, nil
},
},
next: func(w http.ResponseWriter, r *http.Request) {
contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context())
assert.NotNil(t, eak)
exp := &linkedca.EABKey{
Id: "eakID",
Provisioner: "provID",
Reference: "ref",
CreatedAt: timestamppb.New(createdAt),
BoundAt: timestamppb.New(boundAt),
}
assert.Equals(t, exp, contextEAK)
w.Write(nil) // mock response with status 200
},
err: nil,
statusCode: 200,
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
ctx := acme.NewDatabaseContext(tc.ctx, tc.acmeDB)
req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.extractAuthorizeTokenAdmin(tc.next)(w, req)
loadExternalAccountKey(tc.next)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)

@ -0,0 +1,496 @@
package api
import (
"context"
"errors"
"net/http"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/policy"
)
// PolicyAdminResponder is the interface responsible for writing ACME admin
// responses.
type PolicyAdminResponder interface {
GetAuthorityPolicy(w http.ResponseWriter, r *http.Request)
CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request)
GetProvisionerPolicy(w http.ResponseWriter, r *http.Request)
CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request)
GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
}
// policyAdminResponder implements PolicyAdminResponder.
type policyAdminResponder struct{}
// NewACMEAdminResponder returns a new PolicyAdminResponder.
func NewPolicyAdminResponder() PolicyAdminResponder {
return &policyAdminResponder{}
}
// GetAuthorityPolicy handles the GET /admin/authority/policy request
func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
auth := mustAuthority(ctx)
authorityPolicy, err := auth.GetAuthorityPolicy(r.Context())
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
return
}
if authorityPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
return
}
render.ProtoJSONStatus(w, authorityPolicy, http.StatusOK)
}
// CreateAuthorityPolicy handles the POST /admin/authority/policy request
func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
auth := mustAuthority(ctx)
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
return
}
if authorityPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "authority already has a policy")
render.Error(w, adminErr)
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
return
}
adm := linkedca.MustAdminFromContext(ctx)
var createdPolicy *linkedca.Policy
if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error storing authority policy"))
return
}
render.ProtoJSONStatus(w, createdPolicy, http.StatusCreated)
}
// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request
func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
auth := mustAuthority(ctx)
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
return
}
if authorityPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
return
}
adm := linkedca.MustAdminFromContext(ctx)
var updatedPolicy *linkedca.Policy
if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error updating authority policy"))
return
}
render.ProtoJSONStatus(w, updatedPolicy, http.StatusOK)
}
// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request
func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
auth := mustAuthority(ctx)
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
return
}
if authorityPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
return
}
if err := auth.RemoveAuthorityPolicy(ctx); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy"))
return
}
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
}
// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request
func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
provisionerPolicy := prov.GetPolicy()
if provisionerPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
return
}
render.ProtoJSONStatus(w, provisionerPolicy, http.StatusOK)
}
// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request
func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
provisionerPolicy := prov.GetPolicy()
if provisionerPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name)
render.Error(w, adminErr)
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
return
}
prov.Policy = newPolicy
auth := mustAuthority(ctx)
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error creating provisioner policy"))
return
}
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
}
// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request
func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
provisionerPolicy := prov.GetPolicy()
if provisionerPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
return
}
prov.Policy = newPolicy
auth := mustAuthority(ctx)
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error updating provisioner policy"))
return
}
render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
}
// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request
func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
if prov.Policy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
return
}
// remove the policy
prov.Policy = nil
auth := mustAuthority(ctx)
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy"))
return
}
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
}
func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy()
if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
return
}
render.ProtoJSONStatus(w, eakPolicy, http.StatusOK)
}
func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy()
if eakPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id)
render.Error(w, adminErr)
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
return
}
eak.Policy = newPolicy
acmeEAK := linkedEAKToCertificates(eak)
acmeDB := acme.MustDatabaseFromContext(ctx)
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy"))
return
}
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
}
func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy()
if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
return
}
eak.Policy = newPolicy
acmeEAK := linkedEAKToCertificates(eak)
acmeDB := acme.MustDatabaseFromContext(ctx)
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy"))
return
}
render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
}
func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy()
if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
return
}
// remove the policy
eak.Policy = nil
acmeEAK := linkedEAKToCertificates(eak)
acmeDB := acme.MustDatabaseFromContext(ctx)
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy"))
return
}
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
}
// blockLinkedCA blocks all API operations on linked deployments
func blockLinkedCA(ctx context.Context) error {
// temporary blocking linked deployments
adminDB := admin.MustFromContext(ctx)
if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok && a.IsLinkedCA() {
return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
}
return nil
}
// isBadRequest checks if an error should result in a bad request error
// returned to the client.
func isBadRequest(err error) bool {
var pe *authority.PolicyError
isPolicyError := errors.As(err, &pe)
return isPolicyError && (pe.Typ == authority.AdminLockOut || pe.Typ == authority.EvaluationFailure || pe.Typ == authority.ConfigurationFailure)
}
func validatePolicy(p *linkedca.Policy) error {
// convert the policy; return early if nil
options := policy.LinkedToCertificates(p)
if options == nil {
return nil
}
var err error
// Initialize a temporary x509 allow/deny policy engine
if _, err = policy.NewX509PolicyEngine(options.GetX509Options()); err != nil {
return err
}
// Initialize a temporary SSH allow/deny policy engine for host certificates
if _, err = policy.NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil {
return err
}
// Initialize a temporary SSH allow/deny policy engine for user certificates
if _, err = policy.NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil {
return err
}
return nil
}

File diff suppressed because it is too large Load Diff

@ -23,29 +23,31 @@ type GetProvisionersResponse struct {
}
// GetProvisioner returns the requested provisioner, or an error.
func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := r.URL.Query().Get("id")
name := chi.URLParam(r, "name")
func GetProvisioner(w http.ResponseWriter, r *http.Request) {
var (
p provisioner.Interface
err error
)
ctx := r.Context()
id := r.URL.Query().Get("id")
name := chi.URLParam(r, "name")
auth := mustAuthority(ctx)
db := admin.MustFromContext(ctx)
if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
if p, err = auth.LoadProvisionerByID(id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return
}
} else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
if p, err = auth.LoadProvisionerByName(name); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return
}
}
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
prov, err := db.GetProvisioner(ctx, p.GetID())
if err != nil {
render.Error(w, err)
return
@ -54,7 +56,7 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
}
// GetProvisioners returns the given segment of provisioners associated with the authority.
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
func GetProvisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r)
if err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
@ -62,7 +64,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
return
}
p, next, err := h.auth.GetProvisioners(cursor, limit)
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
if err != nil {
render.Error(w, errs.InternalServerErr(err))
return
@ -74,7 +76,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
}
// CreateProvisioner creates a new prov.
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
func CreateProvisioner(w http.ResponseWriter, r *http.Request) {
var prov = new(linkedca.Provisioner)
if err := read.ProtoJSON(r.Body, prov); err != nil {
render.Error(w, err)
@ -87,7 +89,7 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil {
if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
return
}
@ -95,27 +97,29 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
}
// DeleteProvisioner deletes a provisioner.
func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
id := r.URL.Query().Get("id")
name := chi.URLParam(r, "name")
func DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
var (
p provisioner.Interface
err error
)
id := r.URL.Query().Get("id")
name := chi.URLParam(r, "name")
auth := mustAuthority(r.Context())
if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
if p, err = auth.LoadProvisionerByID(id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return
}
} else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
if p, err = auth.LoadProvisionerByName(name); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return
}
}
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
return
}
@ -124,23 +128,27 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
}
// UpdateProvisioner updates an existing prov.
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
func UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
var nu = new(linkedca.Provisioner)
if err := read.ProtoJSON(r.Body, nu); err != nil {
render.Error(w, err)
return
}
ctx := r.Context()
name := chi.URLParam(r, "name")
_old, err := h.auth.LoadProvisionerByName(name)
auth := mustAuthority(ctx)
db := admin.MustFromContext(ctx)
p, err := auth.LoadProvisionerByName(name)
if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
return
}
old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID())
old, err := db.GetProvisioner(r.Context(), p.GetID())
if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID()))
return
}
@ -171,7 +179,7 @@ func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil {
if err := auth.UpdateProvisioner(r.Context(), nu); err != nil {
render.Error(w, err)
return
}

@ -8,18 +8,21 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
)
func TestHandler_GetProvisioner(t *testing.T) {
@ -47,6 +50,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
ctx: ctx,
req: req,
auth: auth,
adminDB: &admin.MockDB{},
statusCode: 500,
err: &admin.Error{
Type: admin.ErrorServerInternalType.String(),
@ -71,6 +75,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
ctx: ctx,
req: req,
auth: auth,
adminDB: &admin.MockDB{},
statusCode: 500,
err: &admin.Error{
Type: admin.ErrorServerInternalType.String(),
@ -153,13 +158,11 @@ func TestHandler_GetProvisioner(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
adminDB: tc.adminDB,
}
req := tc.req.WithContext(tc.ctx)
mockMustAuthority(t, tc.auth)
ctx := admin.NewContext(tc.ctx, tc.adminDB)
req := tc.req.WithContext(ctx)
w := httptest.NewRecorder()
h.GetProvisioner(w, req)
GetProvisioner(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -277,12 +280,10 @@ func TestHandler_GetProvisioners(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetProvisioners(w, req)
GetProvisioners(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -335,12 +336,12 @@ func TestHandler_CreateProvisioner(t *testing.T) {
return test{
ctx: context.Background(),
body: body,
statusCode: 500,
err: &admin.Error{ // TODO(hs): this probably needs a better error
Type: "",
Status: 500,
Detail: "",
Message: "",
statusCode: 400,
err: &admin.Error{
Type: "badRequest",
Status: 400,
Detail: "bad request",
Message: "proto: syntax error (line 1:2): invalid value !",
},
}
},
@ -402,13 +403,11 @@ func TestHandler_CreateProvisioner(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.CreateProvisioner(w, req)
CreateProvisioner(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -423,9 +422,15 @@ func TestHandler_CreateProvisioner(t *testing.T) {
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
assert.Equals(t, tc.err.Type, adminErr.Type)
assert.Equals(t, tc.err.Message, adminErr.Message)
assert.Equals(t, tc.err.Detail, adminErr.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
if strings.HasPrefix(tc.err.Message, "proto:") {
assert.True(t, strings.Contains(adminErr.Message, "syntax error"))
} else {
assert.Equals(t, tc.err.Message, adminErr.Message)
}
return
}
@ -562,12 +567,10 @@ func TestHandler_DeleteProvisioner(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
}
mockMustAuthority(t, tc.auth)
req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.DeleteProvisioner(w, req)
DeleteProvisioner(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -616,12 +619,13 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
return test{
ctx: context.Background(),
body: body,
statusCode: 500,
err: &admin.Error{ // TODO(hs): this probably needs a better error
Type: "",
Status: 500,
Detail: "",
Message: "",
adminDB: &admin.MockDB{},
statusCode: 400,
err: &admin.Error{
Type: "badRequest",
Status: 400,
Detail: "bad request",
Message: "proto: syntax error (line 1:2): invalid value !",
},
}
},
@ -645,6 +649,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
return test{
ctx: ctx,
body: body,
adminDB: &admin.MockDB{},
auth: auth,
statusCode: 500,
err: &admin.Error{
@ -1052,14 +1057,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
adminDB: tc.adminDB,
}
mockMustAuthority(t, tc.auth)
ctx := admin.NewContext(tc.ctx, tc.adminDB)
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
h.UpdateProvisioner(w, req)
UpdateProvisioner(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
@ -1074,9 +1077,15 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
assert.Equals(t, tc.err.Type, adminErr.Type)
assert.Equals(t, tc.err.Message, adminErr.Message)
assert.Equals(t, tc.err.Detail, adminErr.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
if strings.HasPrefix(tc.err.Message, "proto:") {
assert.True(t, strings.Contains(adminErr.Message, "syntax error"))
} else {
assert.Equals(t, tc.err.Message, adminErr.Message)
}
return
}

@ -69,6 +69,34 @@ type DB interface {
GetAdmins(ctx context.Context) ([]*linkedca.Admin, error)
UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error
DeleteAdmin(ctx context.Context, id string) error
CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
DeleteAuthorityPolicy(ctx context.Context) error
}
type dbKey struct{}
// NewContext adds the given admin database to the context.
func NewContext(ctx context.Context, db DB) context.Context {
return context.WithValue(ctx, dbKey{}, db)
}
// FromContext returns the current admin database from the given context.
func FromContext(ctx context.Context) (db DB, ok bool) {
db, ok = ctx.Value(dbKey{}).(DB)
return
}
// MustFromContext returns the current admin database from the given context. It
// will panic if it's not in the context.
func MustFromContext(ctx context.Context) DB {
if db, ok := FromContext(ctx); !ok {
panic("admin database is not in the context")
} else {
return db
}
}
// MockDB is an implementation of the DB interface that should only be used as
@ -86,6 +114,11 @@ type MockDB struct {
MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error
MockDeleteAdmin func(ctx context.Context, id string) error
MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
MockDeleteAuthorityPolicy func(ctx context.Context) error
MockError error
MockRet1 interface{}
}
@ -179,3 +212,35 @@ func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error {
}
return m.MockError
}
// CreateAuthorityPolicy mock
func (m *MockDB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
if m.MockCreateAuthorityPolicy != nil {
return m.MockCreateAuthorityPolicy(ctx, policy)
}
return m.MockError
}
// GetAuthorityPolicy mock
func (m *MockDB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
if m.MockGetAuthorityPolicy != nil {
return m.MockGetAuthorityPolicy(ctx)
}
return m.MockRet1.(*linkedca.Policy), m.MockError
}
// UpdateAuthorityPolicy mock
func (m *MockDB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
if m.MockUpdateAuthorityPolicy != nil {
return m.MockUpdateAuthorityPolicy(ctx, policy)
}
return m.MockError
}
// DeleteAuthorityPolicy mock
func (m *MockDB) DeleteAuthorityPolicy(ctx context.Context) error {
if m.MockDeleteAuthorityPolicy != nil {
return m.MockDeleteAuthorityPolicy(ctx)
}
return m.MockError
}

@ -11,8 +11,9 @@ import (
)
var (
adminsTable = []byte("admins")
provisionersTable = []byte("provisioners")
adminsTable = []byte("admins")
provisionersTable = []byte("provisioners")
authorityPoliciesTable = []byte("authority_policies")
)
// DB is a struct that implements the AdminDB interface.
@ -23,7 +24,7 @@ type DB struct {
// New configures and returns a new Authority DB backend implemented using a nosql DB.
func New(db nosqlDB.DB, authorityID string) (*DB, error) {
tables := [][]byte{adminsTable, provisionersTable}
tables := [][]byte{adminsTable, provisionersTable, authorityPoliciesTable}
for _, b := range tables {
if err := db.CreateTable(b); err != nil {
return nil, errors.Wrapf(err, "error creating table %s",

@ -0,0 +1,339 @@
package nosql
import (
"context"
"encoding/json"
"fmt"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/nosql"
)
type dbX509Policy struct {
Allow *dbX509Names `json:"allow,omitempty"`
Deny *dbX509Names `json:"deny,omitempty"`
AllowWildcardNames bool `json:"allow_wildcard_names,omitempty"`
}
type dbX509Names struct {
CommonNames []string `json:"cn,omitempty"`
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
EmailAddresses []string `json:"email,omitempty"`
URIDomains []string `json:"uri,omitempty"`
}
type dbSSHPolicy struct {
// User contains SSH user certificate options.
User *dbSSHUserPolicy `json:"user,omitempty"`
// Host contains SSH host certificate options.
Host *dbSSHHostPolicy `json:"host,omitempty"`
}
type dbSSHHostPolicy struct {
Allow *dbSSHHostNames `json:"allow,omitempty"`
Deny *dbSSHHostNames `json:"deny,omitempty"`
}
type dbSSHHostNames struct {
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
Principals []string `json:"principal,omitempty"`
}
type dbSSHUserPolicy struct {
Allow *dbSSHUserNames `json:"allow,omitempty"`
Deny *dbSSHUserNames `json:"deny,omitempty"`
}
type dbSSHUserNames struct {
EmailAddresses []string `json:"email,omitempty"`
Principals []string `json:"principal,omitempty"`
}
type dbPolicy struct {
X509 *dbX509Policy `json:"x509,omitempty"`
SSH *dbSSHPolicy `json:"ssh,omitempty"`
}
type dbAuthorityPolicy struct {
ID string `json:"id"`
AuthorityID string `json:"authorityID"`
Policy *dbPolicy `json:"policy,omitempty"`
}
func (dbap *dbAuthorityPolicy) convert() *linkedca.Policy {
if dbap == nil {
return nil
}
return dbToLinked(dbap.Policy)
}
func (db *DB) getDBAuthorityPolicyBytes(ctx context.Context, authorityID string) ([]byte, error) {
data, err := db.db.Get(authorityPoliciesTable, []byte(authorityID))
if nosql.IsErrNotFound(err) {
return nil, admin.NewError(admin.ErrorNotFoundType, "authority policy not found")
} else if err != nil {
return nil, fmt.Errorf("error loading authority policy: %w", err)
}
return data, nil
}
func (db *DB) unmarshalDBAuthorityPolicy(data []byte) (*dbAuthorityPolicy, error) {
if len(data) == 0 {
return nil, nil
}
var dba = new(dbAuthorityPolicy)
if err := json.Unmarshal(data, dba); err != nil {
return nil, fmt.Errorf("error unmarshaling policy bytes into dbAuthorityPolicy: %w", err)
}
return dba, nil
}
func (db *DB) getDBAuthorityPolicy(ctx context.Context, authorityID string) (*dbAuthorityPolicy, error) {
data, err := db.getDBAuthorityPolicyBytes(ctx, authorityID)
if err != nil {
return nil, err
}
dbap, err := db.unmarshalDBAuthorityPolicy(data)
if err != nil {
return nil, err
}
if dbap == nil {
return nil, nil
}
if dbap.AuthorityID != authorityID {
return nil, admin.NewError(admin.ErrorAuthorityMismatchType,
"authority policy is not owned by authority %s", authorityID)
}
return dbap, nil
}
func (db *DB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
dbap := &dbAuthorityPolicy{
ID: db.authorityID,
AuthorityID: db.authorityID,
Policy: linkedToDB(policy),
}
if err := db.save(ctx, dbap.ID, dbap, nil, "authority_policy", authorityPoliciesTable); err != nil {
return admin.WrapErrorISE(err, "error creating authority policy")
}
return nil
}
func (db *DB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
if err != nil {
return nil, err
}
return dbap.convert(), nil
}
func (db *DB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
old, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
if err != nil {
return err
}
dbap := &dbAuthorityPolicy{
ID: db.authorityID,
AuthorityID: db.authorityID,
Policy: linkedToDB(policy),
}
if err := db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable); err != nil {
return admin.WrapErrorISE(err, "error updating authority policy")
}
return nil
}
func (db *DB) DeleteAuthorityPolicy(ctx context.Context) error {
old, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
if err != nil {
return err
}
if err := db.save(ctx, old.ID, nil, old, "authority_policy", authorityPoliciesTable); err != nil {
return admin.WrapErrorISE(err, "error deleting authority policy")
}
return nil
}
func dbToLinked(p *dbPolicy) *linkedca.Policy {
if p == nil {
return nil
}
r := &linkedca.Policy{}
if x509 := p.X509; x509 != nil {
r.X509 = &linkedca.X509Policy{}
if allow := x509.Allow; allow != nil {
r.X509.Allow = &linkedca.X509Names{}
r.X509.Allow.Dns = allow.DNSDomains
r.X509.Allow.Emails = allow.EmailAddresses
r.X509.Allow.Ips = allow.IPRanges
r.X509.Allow.Uris = allow.URIDomains
r.X509.Allow.CommonNames = allow.CommonNames
}
if deny := x509.Deny; deny != nil {
r.X509.Deny = &linkedca.X509Names{}
r.X509.Deny.Dns = deny.DNSDomains
r.X509.Deny.Emails = deny.EmailAddresses
r.X509.Deny.Ips = deny.IPRanges
r.X509.Deny.Uris = deny.URIDomains
r.X509.Deny.CommonNames = deny.CommonNames
}
r.X509.AllowWildcardNames = x509.AllowWildcardNames
}
if ssh := p.SSH; ssh != nil {
r.Ssh = &linkedca.SSHPolicy{}
if host := ssh.Host; host != nil {
r.Ssh.Host = &linkedca.SSHHostPolicy{}
if allow := host.Allow; allow != nil {
r.Ssh.Host.Allow = &linkedca.SSHHostNames{}
r.Ssh.Host.Allow.Dns = allow.DNSDomains
r.Ssh.Host.Allow.Ips = allow.IPRanges
r.Ssh.Host.Allow.Principals = allow.Principals
}
if deny := host.Deny; deny != nil {
r.Ssh.Host.Deny = &linkedca.SSHHostNames{}
r.Ssh.Host.Deny.Dns = deny.DNSDomains
r.Ssh.Host.Deny.Ips = deny.IPRanges
r.Ssh.Host.Deny.Principals = deny.Principals
}
}
if user := ssh.User; user != nil {
r.Ssh.User = &linkedca.SSHUserPolicy{}
if allow := user.Allow; allow != nil {
r.Ssh.User.Allow = &linkedca.SSHUserNames{}
r.Ssh.User.Allow.Emails = allow.EmailAddresses
r.Ssh.User.Allow.Principals = allow.Principals
}
if deny := user.Deny; deny != nil {
r.Ssh.User.Deny = &linkedca.SSHUserNames{}
r.Ssh.User.Deny.Emails = deny.EmailAddresses
r.Ssh.User.Deny.Principals = deny.Principals
}
}
}
return r
}
func linkedToDB(p *linkedca.Policy) *dbPolicy {
if p == nil {
return nil
}
// return early if x509 nor SSH is set
if p.GetX509() == nil && p.GetSsh() == nil {
return nil
}
r := &dbPolicy{}
// fill x509 policy configuration
if x509 := p.GetX509(); x509 != nil {
r.X509 = &dbX509Policy{}
if allow := x509.GetAllow(); allow != nil {
r.X509.Allow = &dbX509Names{}
if allow.Dns != nil {
r.X509.Allow.DNSDomains = allow.Dns
}
if allow.Ips != nil {
r.X509.Allow.IPRanges = allow.Ips
}
if allow.Emails != nil {
r.X509.Allow.EmailAddresses = allow.Emails
}
if allow.Uris != nil {
r.X509.Allow.URIDomains = allow.Uris
}
if allow.CommonNames != nil {
r.X509.Allow.CommonNames = allow.CommonNames
}
}
if deny := x509.GetDeny(); deny != nil {
r.X509.Deny = &dbX509Names{}
if deny.Dns != nil {
r.X509.Deny.DNSDomains = deny.Dns
}
if deny.Ips != nil {
r.X509.Deny.IPRanges = deny.Ips
}
if deny.Emails != nil {
r.X509.Deny.EmailAddresses = deny.Emails
}
if deny.Uris != nil {
r.X509.Deny.URIDomains = deny.Uris
}
if deny.CommonNames != nil {
r.X509.Deny.CommonNames = deny.CommonNames
}
}
r.X509.AllowWildcardNames = x509.GetAllowWildcardNames()
}
// fill ssh policy configuration
if ssh := p.GetSsh(); ssh != nil {
r.SSH = &dbSSHPolicy{}
if host := ssh.GetHost(); host != nil {
r.SSH.Host = &dbSSHHostPolicy{}
if allow := host.GetAllow(); allow != nil {
r.SSH.Host.Allow = &dbSSHHostNames{}
if allow.Dns != nil {
r.SSH.Host.Allow.DNSDomains = allow.Dns
}
if allow.Ips != nil {
r.SSH.Host.Allow.IPRanges = allow.Ips
}
if allow.Principals != nil {
r.SSH.Host.Allow.Principals = allow.Principals
}
}
if deny := host.GetDeny(); deny != nil {
r.SSH.Host.Deny = &dbSSHHostNames{}
if deny.Dns != nil {
r.SSH.Host.Deny.DNSDomains = deny.Dns
}
if deny.Ips != nil {
r.SSH.Host.Deny.IPRanges = deny.Ips
}
if deny.Principals != nil {
r.SSH.Host.Deny.Principals = deny.Principals
}
}
}
if user := ssh.GetUser(); user != nil {
r.SSH.User = &dbSSHUserPolicy{}
if allow := user.GetAllow(); allow != nil {
r.SSH.User.Allow = &dbSSHUserNames{}
if allow.Emails != nil {
r.SSH.User.Allow.EmailAddresses = allow.Emails
}
if allow.Principals != nil {
r.SSH.User.Allow.Principals = allow.Principals
}
}
if deny := user.GetDeny(); deny != nil {
r.SSH.User.Deny = &dbSSHUserNames{}
if deny.Emails != nil {
r.SSH.User.Deny.EmailAddresses = deny.Emails
}
if deny.Principals != nil {
r.SSH.User.Deny.Principals = deny.Principals
}
}
}
}
return r
}

File diff suppressed because it is too large Load Diff

@ -24,10 +24,12 @@ const (
ErrorBadRequestType
// ErrorNotImplementedType not implemented.
ErrorNotImplementedType
// ErrorUnauthorizedType internal server error.
// ErrorUnauthorizedType unauthorized.
ErrorUnauthorizedType
// ErrorServerInternalType internal server error.
ErrorServerInternalType
// ErrorConflictType conflict.
ErrorConflictType
)
// String returns the string representation of the admin problem type,
@ -48,6 +50,8 @@ func (ap ProblemType) String() string {
return "unauthorized"
case ErrorServerInternalType:
return "internalServerError"
case ErrorConflictType:
return "conflict"
default:
return fmt.Sprintf("unsupported error type '%d'", int(ap))
}
@ -64,7 +68,7 @@ var (
errorServerInternalMetadata = errorMetadata{
typ: ErrorServerInternalType.String(),
details: "the server experienced an internal error",
status: 500,
status: http.StatusInternalServerError,
}
errorMap = map[ProblemType]errorMetadata{
ErrorNotFoundType: {
@ -98,6 +102,11 @@ var (
status: http.StatusUnauthorized,
},
ErrorServerInternalType: errorServerInternalMetadata,
ErrorConflictType: {
typ: ErrorConflictType.String(),
details: "conflict",
status: http.StatusConflict,
},
}
)

@ -59,12 +59,12 @@ func newSubProv(subject, prov string) subProv {
return subProv{subject, prov}
}
// LoadBySubProv a admin by the subject and provisioner name.
// LoadBySubProv loads an admin by subject and provisioner name.
func (c *Collection) LoadBySubProv(sub, provName string) (*linkedca.Admin, bool) {
return loadAdmin(c.bySubProv, newSubProv(sub, provName))
}
// LoadByProvisioner a admin by the subject and provisioner name.
// LoadByProvisioner loads admins by provisioner name.
func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool) {
val, ok := c.byProv.Load(provName)
if !ok {
@ -78,7 +78,7 @@ func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool
}
// Store adds an admin to the collection and enforces the uniqueness of
// admin IDs and amdin subject <-> provisioner name combos.
// admin IDs and admin subject <-> provisioner name combos.
func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error {
// Input validation.
if adm.ProvisionerId != prov.GetID() {

@ -49,7 +49,7 @@ func (a *Authority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov pr
return admin.WrapErrorISE(err, "error creating admin")
}
if err := a.admins.Store(adm, prov); err != nil {
if err := a.reloadAdminResources(ctx); err != nil {
if err := a.ReloadAdminResources(ctx); err != nil {
return admin.WrapErrorISE(err, "error reloading admin resources on failed admin store")
}
return admin.WrapErrorISE(err, "error storing admin in authority cache")
@ -66,7 +66,7 @@ func (a *Authority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Adm
return nil, admin.WrapErrorISE(err, "error updating cached admin %s", id)
}
if err := a.adminDB.UpdateAdmin(ctx, adm); err != nil {
if err := a.reloadAdminResources(ctx); err != nil {
if err := a.ReloadAdminResources(ctx); err != nil {
return nil, admin.WrapErrorISE(err, "error reloading admin resources on failed admin update")
}
return nil, admin.WrapErrorISE(err, "error updating admin %s", id)
@ -88,7 +88,7 @@ func (a *Authority) removeAdmin(ctx context.Context, id string) error {
return admin.WrapErrorISE(err, "error removing admin %s from authority cache", id)
}
if err := a.adminDB.DeleteAdmin(ctx, id); err != nil {
if err := a.reloadAdminResources(ctx); err != nil {
if err := a.ReloadAdminResources(ctx); err != nil {
return admin.WrapErrorISE(err, "error reloading admin resources on failed admin remove")
}
return admin.WrapErrorISE(err, "error deleting admin %s", id)

@ -13,10 +13,16 @@ import (
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/pemutil"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/admin"
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
"github.com/smallstep/certificates/authority/administrator"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/cas"
casapi "github.com/smallstep/certificates/cas/apiv1"
@ -27,9 +33,6 @@ import (
"github.com/smallstep/certificates/scep"
"github.com/smallstep/certificates/templates"
"github.com/smallstep/nosql"
"go.step.sm/crypto/pemutil"
"go.step.sm/linkedca"
"golang.org/x/crypto/ssh"
)
// Authority implements the Certificate Authority internal interface.
@ -81,9 +84,16 @@ type Authority struct {
authorizeRenewFunc provisioner.AuthorizeRenewFunc
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
// Policy engines
policyEngine *policy.Engine
adminMutex sync.RWMutex
// Do Not initialize the authority
skipInit bool
}
// Info contains information about the authority.
type Info struct {
StartTime time.Time
RootX509Certs []*x509.Certificate
@ -111,9 +121,11 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) {
}
}
// Initialize authority from options or configuration.
if err := a.init(); err != nil {
return nil, err
if !a.skipInit {
// Initialize authority from options or configuration.
if err := a.init(); err != nil {
return nil, err
}
}
return a, nil
@ -149,16 +161,41 @@ func NewEmbedded(opts ...Option) (*Authority, error) {
// Initialize config required fields.
a.config.Init()
// Initialize authority from options or configuration.
if err := a.init(); err != nil {
return nil, err
if !a.skipInit {
// Initialize authority from options or configuration.
if err := a.init(); err != nil {
return nil, err
}
}
return a, nil
}
// reloadAdminResources reloads admins and provisioners from the DB.
func (a *Authority) reloadAdminResources(ctx context.Context) error {
type authorityKey struct{}
// NewContext adds the given authority to the context.
func NewContext(ctx context.Context, a *Authority) context.Context {
return context.WithValue(ctx, authorityKey{}, a)
}
// FromContext returns the current authority from the given context.
func FromContext(ctx context.Context) (a *Authority, ok bool) {
a, ok = ctx.Value(authorityKey{}).(*Authority)
return
}
// MustFromContext returns the current authority from the given context. It will
// panic if the authority is not in the context.
func MustFromContext(ctx context.Context) *Authority {
if a, ok := FromContext(ctx); !ok {
panic("authority is not in the context")
} else {
return a
}
}
// ReloadAdminResources reloads admins and provisioners from the DB.
func (a *Authority) ReloadAdminResources(ctx context.Context) error {
var (
provList provisioner.List
adminList []*linkedca.Admin
@ -213,6 +250,7 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error {
a.provisioners = provClxn
a.config.AuthorityConfig.Admins = adminList
a.admins = adminClxn
return nil
}
@ -224,6 +262,7 @@ func (a *Authority) init() error {
}
var err error
ctx := NewContext(context.Background(), a)
// Set password if they are not set.
var configPassword []byte
@ -259,10 +298,25 @@ func (a *Authority) init() error {
if a.config.KMS != nil {
options = *a.config.KMS
}
a.keyManager, err = kms.New(context.Background(), options)
a.keyManager, err = kms.New(ctx, options)
if err != nil {
return err
}
}
// Initialize linkedca client if necessary. On a linked RA, the issuer
// configuration might come from majordomo.
var linkedcaClient *linkedCaClient
if a.config.AuthorityConfig.EnableAdmin && a.linkedCAToken != "" && a.adminDB == nil {
linkedcaClient, err = newLinkedCAClient(a.linkedCAToken)
if err != nil {
return err
}
// If authorityId is configured make sure it matches the one in the token
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, linkedcaClient.authorityID) {
return errors.New("error initializing linkedca: token authority and configured authority do not match")
}
linkedcaClient.Run()
}
// Initialize the X.509 CA Service if it has not been set in the options.
@ -272,6 +326,22 @@ func (a *Authority) init() error {
options = *a.config.AuthorityConfig.Options
}
// Configure linked RA
if linkedcaClient != nil && options.CertificateAuthority == "" {
conf, err := linkedcaClient.GetConfiguration(ctx)
if err != nil {
return err
}
if conf.RaConfig != nil {
options.CertificateAuthority = conf.RaConfig.CaUrl
options.CertificateAuthorityFingerprint = conf.RaConfig.Fingerprint
options.CertificateIssuer = &casapi.CertificateIssuer{
Type: conf.RaConfig.Provisioner.Type.String(),
Provisioner: conf.RaConfig.Provisioner.Name,
}
}
}
// Set the issuer password if passed in the flags.
if options.CertificateIssuer != nil && a.issuerPassword != nil {
options.CertificateIssuer.Password = string(a.issuerPassword)
@ -292,7 +362,7 @@ func (a *Authority) init() error {
}
}
a.x509CAService, err = cas.New(context.Background(), options)
a.x509CAService, err = cas.New(ctx, options)
if err != nil {
return err
}
@ -479,7 +549,7 @@ func (a *Authority) init() error {
}
}
a.scepService, err = scep.NewService(context.Background(), options)
a.scepService, err = scep.NewService(ctx, options)
if err != nil {
return err
}
@ -491,40 +561,29 @@ func (a *Authority) init() error {
// Initialize step-ca Admin Database if it's not already initialized using
// WithAdminDB.
if a.adminDB == nil {
if a.linkedCAToken == "" {
// Check if AuthConfig already exists
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
if err != nil {
return err
}
if linkedcaClient != nil {
a.adminDB = linkedcaClient
} else {
// Use the linkedca client as the admindb.
client, err := newLinkedCAClient(a.linkedCAToken)
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
if err != nil {
return err
}
// If authorityId is configured make sure it matches the one in the token
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, client.authorityID) {
return errors.New("error initializing linkedca: token authority and configured authority do not match")
}
client.Run()
a.adminDB = client
}
}
provs, err := a.adminDB.GetProvisioners(context.Background())
provs, err := a.adminDB.GetProvisioners(ctx)
if err != nil {
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
}
if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
// Create First Provisioner
prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password))
prov, err := CreateFirstProvisioner(ctx, a.adminDB, string(a.password))
if err != nil {
return admin.WrapErrorISE(err, "error creating first provisioner")
}
// Create first admin
if err := a.adminDB.CreateAdmin(context.Background(), &linkedca.Admin{
if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{
ProvisionerId: prov.Id,
Subject: "step",
Type: linkedca.Admin_SUPER_ADMIN,
@ -535,7 +594,12 @@ func (a *Authority) init() error {
}
// Load Provisioners and Admins
if err := a.reloadAdminResources(context.Background()); err != nil {
if err := a.ReloadAdminResources(ctx); err != nil {
return err
}
// Load x509 and SSH Policy Engines
if err := a.reloadPolicyEngines(ctx); err != nil {
return err
}
@ -570,6 +634,15 @@ func (a *Authority) init() error {
return nil
}
// GetID returns the define authority id or a zero uuid.
func (a *Authority) GetID() string {
const zeroUUID = "00000000-0000-0000-0000-000000000000"
if id := a.config.AuthorityConfig.AuthorityID; id != "" {
return id
}
return zeroUUID
}
// GetDatabase returns the authority database. If the configuration does not
// define a database, GetDatabase will return a db.SimpleDB instance.
func (a *Authority) GetDatabase() db.AuthDB {
@ -581,6 +654,12 @@ func (a *Authority) GetAdminDatabase() admin.DB {
return a.adminDB
}
// GetConfig returns the config.
func (a *Authority) GetConfig() *config.Config {
return a.config
}
// GetInfo returns information about the authority.
func (a *Authority) GetInfo() Info {
ai := Info{
StartTime: a.startTime,

@ -14,6 +14,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"go.step.sm/crypto/jose"
@ -421,3 +422,31 @@ func TestAuthority_GetSCEPService(t *testing.T) {
})
}
}
func TestAuthority_GetID(t *testing.T) {
type fields struct {
authorityID string
}
tests := []struct {
name string
fields fields
want string
}{
{"ok", fields{""}, "00000000-0000-0000-0000-000000000000"},
{"ok with id", fields{"10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, "10b9a431-ed3b-4a5f-abee-ec35119b65e7"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Authority{
config: &config.Config{
AuthorityConfig: &config.AuthConfig{
AuthorityID: tt.fields.authorityID,
},
},
}
if got := a.GetID(); got != tt.want {
t.Errorf("Authority.GetID() = %v, want %v", got, tt.want)
}
})
}
}

@ -5,6 +5,7 @@ import (
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"fmt"
"net/http"
"net/url"
"strconv"
@ -41,14 +42,12 @@ func SkipTokenReuseFromContext(ctx context.Context) bool {
return m
}
// authorizeToken parses the token and returns the provisioner used to generate
// the token. This method enforces the One-Time use policy (tokens can only be
// used once).
func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) {
// Validate payload
// getProvisionerFromToken extracts a provisioner from the given token without
// doing any token validation.
func (a *Authority) getProvisionerFromToken(token string) (provisioner.Interface, *Claims, error) {
tok, err := jose.ParseSigned(token)
if err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken: error parsing token")
return nil, nil, fmt.Errorf("error parsing token: %w", err)
}
// Get claims w/out verification. We need to look up the provisioner
@ -56,7 +55,25 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// before we can look up the provisioner.
var claims Claims
if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken")
return nil, nil, fmt.Errorf("error unmarshaling token: %w", err)
}
// This method will also validate the audiences for JWK provisioners.
p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
if !ok {
return nil, nil, fmt.Errorf("provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
}
return p, &claims, nil
}
// authorizeToken parses the token and returns the provisioner used to generate
// the token. This method enforces the One-Time use policy (tokens can only be
// used once).
func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) {
p, claims, err := a.getProvisionerFromToken(token)
if err != nil {
return nil, errs.UnauthorizedErr(err)
}
// TODO: use new persistence layer abstraction.
@ -64,17 +81,10 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// This check is meant as a stopgap solution to the current lack of a persistence layer.
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) {
return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority")
return nil, errs.Unauthorized("token issued before the bootstrap of certificate authority")
}
}
// This method will also validate the audiences for JWK provisioners.
p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
if !ok {
return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+
"not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
}
// Store the token to protect against reuse unless it's skipped.
// If we cannot get a token id from the provisioner, just hash the token.
if !SkipTokenReuseFromContext(ctx) {
@ -130,22 +140,24 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
// According to "rfc7519 JSON Web Token" acceptable skew should be no
// more than a few minutes.
if err := claims.ValidateWithLeeway(jose.Expected{
Issuer: prov.GetName(),
Time: time.Now().UTC(),
Time: time.Now().UTC(),
}, time.Minute); err != nil {
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "x5c.authorizeToken; invalid x5c claims")
}
// validate audience: path matches the current path
if r.URL.Path != claims.Audience[0] {
return nil, admin.NewError(admin.ErrorUnauthorizedType,
"x5c.authorizeToken; x5c token has invalid audience "+
"claim (aud); expected %s, but got %s", r.URL.Path, claims.Audience)
if !matchesAudience(claims.Audience, a.config.Audience(r.URL.Path)) {
return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token has invalid audience claim (aud)")
}
// validate issuer: old versions used the provisioner name, new version uses
// 'step-admin-client/1.0'
if claims.Issuer != "step-admin-client/1.0" && claims.Issuer != prov.GetName() {
return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token has invalid issuer claim (iss)")
}
if claims.Subject == "" {
return nil, admin.NewError(admin.ErrorUnauthorizedType,
"x5c.authorizeToken; x5c token subject cannot be empty")
return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token subject cannot be empty")
}
var (
@ -156,7 +168,7 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
adminSANs := append([]string{leaf.Subject.CommonName}, leaf.DNSNames...)
adminSANs = append(adminSANs, leaf.EmailAddresses...)
for _, san := range adminSANs {
if adm, ok = a.LoadAdminBySubProv(san, claims.Issuer); ok {
if adm, ok = a.LoadAdminBySubProv(san, prov.GetName()); ok {
adminFound = true
break
}
@ -186,11 +198,10 @@ func (a *Authority) UseToken(token string, prov provisioner.Interface) error {
}
ok, err := a.db.UseToken(reuseKey, token)
if err != nil {
return errs.Wrap(http.StatusInternalServerError, err,
"authority.authorizeToken: failed when attempting to store token")
return errs.Wrap(http.StatusInternalServerError, err, "failed when attempting to store token")
}
if !ok {
return errs.Unauthorized("authority.authorizeToken: token already used")
return errs.Unauthorized("token already used")
}
}
return nil
@ -249,10 +260,10 @@ func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisio
// AuthorizeSign authorizes a signature request by validating and authenticating
// a token that must be sent w/ the request.
//
// NOTE: This method is deprecated and should not be used. We make it available
// in the short term os as not to break existing clients.
// Deprecated: Use Authorize(context.Context, string) ([]provisioner.SignOption, error).
func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) {
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
ctx := NewContext(context.Background(), a)
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
return a.Authorize(ctx, token)
}
@ -285,9 +296,16 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
if isRevoked {
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
}
p, ok := a.provisioners.LoadByCertificate(cert)
if !ok {
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
p, err := a.LoadProvisionerByCertificate(cert)
if err != nil {
var ok bool
// For backward compatibility this method will also succeed if the
// certificate does not have a provisioner extension. LoadByCertificate
// returns the noop provisioner if this happens, and it allows
// certificate renewals.
if p, ok = a.provisioners.LoadByCertificate(cert); !ok {
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
}
}
if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
@ -386,8 +404,8 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token"))
}
p, ok := a.provisioners.LoadByCertificate(leaf)
if !ok {
p, err := a.LoadProvisionerByCertificate(leaf)
if err != nil {
return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate")
}
if err := a.UseToken(ott, p); err != nil {
@ -395,7 +413,6 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
}
if err := claims.ValidateWithLeeway(jose.Expected{
Issuer: p.GetName(),
Subject: leaf.Subject.CommonName,
Time: time.Now().UTC(),
}, time.Minute); err != nil {
@ -420,6 +437,12 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
}
// validate issuer: old versions used the provisioner name, new version uses
// 'step-ca-client/1.0'
if claims.Issuer != "step-ca-client/1.0" && claims.Issuer != p.GetName() {
return nil, admin.NewError(admin.ErrorUnauthorizedType, "error validating renew token: invalid issuer claim (iss)")
}
return leaf, nil
}

@ -114,7 +114,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{
auth: a,
token: "foo",
err: errors.New("authority.authorizeToken: error parsing token"),
err: errors.New("error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -133,7 +133,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{
auth: a,
token: raw,
err: errors.New("authority.authorizeToken: token issued before the bootstrap of certificate authority"),
err: errors.New("token issued before the bootstrap of certificate authority"),
code: http.StatusUnauthorized,
}
},
@ -155,7 +155,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{
auth: a,
token: raw,
err: errors.New("authority.authorizeToken: provisioner not found or invalid audience (https://example.com/revoke)"),
err: errors.New("provisioner not found or invalid audience (https://example.com/revoke)"),
code: http.StatusUnauthorized,
}
},
@ -192,7 +192,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{
auth: _a,
token: raw,
err: errors.New("authority.authorizeToken: token already used"),
err: errors.New("token already used"),
code: http.StatusUnauthorized,
}
},
@ -227,7 +227,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{
auth: _a,
token: raw,
err: errors.New("authority.authorizeToken: token already used"),
err: errors.New("token already used"),
code: http.StatusUnauthorized,
}
},
@ -275,7 +275,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{
auth: _a,
token: raw,
err: errors.New("authority.authorizeToken: failed when attempting to store token: force"),
err: errors.New("failed when attempting to store token: force"),
code: http.StatusInternalServerError,
}
},
@ -300,7 +300,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{
auth: _a,
token: raw,
err: errors.New("authority.authorizeToken: token already used"),
err: errors.New("token already used"),
code: http.StatusUnauthorized,
}
},
@ -353,7 +353,7 @@ func TestAuthority_authorizeRevoke(t *testing.T) {
return &authorizeTest{
auth: a,
token: "foo",
err: errors.New("authority.authorizeRevoke: authority.authorizeToken: error parsing token"),
err: errors.New("authority.authorizeRevoke: error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -437,7 +437,7 @@ func TestAuthority_authorizeSign(t *testing.T) {
return &authorizeTest{
auth: a,
token: "foo",
err: errors.New("authority.authorizeSign: authority.authorizeToken: error parsing token"),
err: errors.New("authority.authorizeSign: error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -491,7 +491,7 @@ func TestAuthority_authorizeSign(t *testing.T) {
}
} else {
if assert.Nil(t, tc.err) {
assert.Len(t, 7, got)
assert.Equals(t, 9, len(got)) // number of provisioner.SignOptions returned
}
}
})
@ -524,7 +524,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a,
token: "foo",
ctx: context.Background(),
err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"),
err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -533,7 +533,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a,
token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod),
err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"),
err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -559,7 +559,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a,
token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod),
err: errors.New("authority.Authorize: authority.authorizeRevoke: authority.authorizeToken: error parsing token"),
err: errors.New("authority.Authorize: authority.authorizeRevoke: error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -585,7 +585,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a,
token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod),
err: errors.New("authority.Authorize: authority.authorizeSSHSign: authority.authorizeToken: error parsing token"),
err: errors.New("authority.Authorize: authority.authorizeSSHSign: error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -615,7 +615,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a,
token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod),
err: errors.New("authority.Authorize: authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"),
err: errors.New("authority.Authorize: authority.authorizeSSHRenew: error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -659,7 +659,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a,
token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod),
err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"),
err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -685,7 +685,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a,
token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod),
err: errors.New("authority.Authorize: authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"),
err: errors.New("authority.Authorize: authority.authorizeSSHRekey: error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -847,6 +847,29 @@ func TestAuthority_authorizeRenew(t *testing.T) {
cert: fooCrt,
}
},
"ok/from db": func(t *testing.T) *authorizeTest {
a := testAuthority(t)
a.db = &db.MockAuthDB{
MIsRevoked: func(key string) (bool, error) {
return false, nil
},
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
p, ok := a.provisioners.LoadByName("step-cli")
if !ok {
t.Fatal("provisioner step-cli not found")
}
return &db.CertificateData{
Provisioner: &db.ProvisionerData{
ID: p.GetID(),
},
}, nil
},
}
return &authorizeTest{
auth: a,
cert: fooCrt,
}
},
}
for name, genTestCase := range tests {
@ -965,7 +988,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
return &authorizeTest{
auth: a,
token: "foo",
err: errors.New("authority.authorizeSSHSign: authority.authorizeToken: error parsing token"),
err: errors.New("authority.authorizeSSHSign: error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -1011,7 +1034,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
}
} else {
if assert.Nil(t, tc.err) {
assert.Len(t, 7, got)
assert.Len(t, 9, got) // number of provisioner.SignOptions returned
}
}
})
@ -1059,7 +1082,7 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
return &authorizeTest{
auth: a,
token: "foo",
err: errors.New("authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"),
err: errors.New("authority.authorizeSSHRenew: error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -1167,7 +1190,7 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) {
return &authorizeTest{
auth: a,
token: "foo",
err: errors.New("authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"),
err: errors.New("authority.authorizeSSHRevoke: error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -1259,7 +1282,7 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
return &authorizeTest{
auth: a,
token: "foo",
err: errors.New("authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"),
err: errors.New("authority.authorizeSSHRekey: error parsing token"),
code: http.StatusUnauthorized,
}
},
@ -1322,7 +1345,7 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.cert.Serial, cert.Serial)
assert.Len(t, 3, signOpts)
assert.Len(t, 4, signOpts)
}
}
})
@ -1381,7 +1404,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
t1, c1 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1400,7 +1423,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
t2, c2 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now),
@ -1417,12 +1440,31 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
})
return nil
}))
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
t3, c3 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
@ -1439,7 +1481,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1477,7 +1519,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badSubject, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "bad-subject",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1496,7 +1538,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1515,7 +1557,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1534,7 +1576,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)),
@ -1554,7 +1596,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badAudience, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1584,6 +1626,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
}{
{"ok", a1, args{ctx, t1}, c1, false},
{"ok expired cert", a1, args{ctx, t2}, c2, false},
{"ok provisioner issuer", a1, args{ctx, t3}, c3, false},
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
{"fail token reuse", a1, args{ctx, t1}, nil, true},
{"fail token signature", a1, args{ctx, badSigner}, nil, true},

@ -8,12 +8,15 @@ import (
"time"
"github.com/pkg/errors"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner"
cas "github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db"
kms "github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/certificates/templates"
"go.step.sm/linkedca"
)
const (
@ -26,27 +29,27 @@ var (
DefaultBackdate = time.Minute
// DefaultDisableRenewal disables renewals per provisioner.
DefaultDisableRenewal = false
// DefaultAllowRenewAfterExpiry allows renewals even if the certificate is
// DefaultAllowRenewalAfterExpiry allows renewals even if the certificate is
// expired.
DefaultAllowRenewAfterExpiry = false
DefaultAllowRenewalAfterExpiry = false
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally
// for all provisioners.
DefaultEnableSSHCA = false
// GlobalProvisionerClaims default claims for the Authority. Can be overridden
// by provisioner specific claims.
GlobalProvisionerClaims = provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &DefaultEnableSSHCA,
DisableRenewal: &DefaultDisableRenewal,
AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry,
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &DefaultEnableSSHCA,
DisableRenewal: &DefaultDisableRenewal,
AllowRenewalAfterExpiry: &DefaultAllowRenewalAfterExpiry,
}
)
@ -68,6 +71,7 @@ type Config struct {
TLS *TLSOptions `json:"tls,omitempty"`
Password string `json:"password,omitempty"`
Templates *templates.Templates `json:"templates,omitempty"`
CommonName string `json:"commonName,omitempty"`
CRL *CRLConfig `json:"crl,omitempty"`
}
@ -95,6 +99,7 @@ type AuthConfig struct {
Admins []*linkedca.Admin `json:"-"`
Template *ASN1DN `json:"template,omitempty"`
Claims *provisioner.Claims `json:"claims,omitempty"`
Policy *policy.Options `json:"policy,omitempty"`
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
Backdate *provisioner.Duration `json:"backdate,omitempty"`
EnableAdmin bool `json:"enableAdmin,omitempty"`
@ -180,6 +185,9 @@ func (c *Config) Init() {
if c.AuthorityConfig == nil {
c.AuthorityConfig = &AuthConfig{}
}
if c.CommonName == "" {
c.CommonName = "Step Online CA"
}
c.AuthorityConfig.init()
}
@ -311,6 +319,18 @@ func (c *Config) GetAudiences() provisioner.Audiences {
return audiences
}
// Audience returns the list of audiences for a given path.
func (c *Config) Audience(path string) []string {
audiences := make([]string, len(c.DNSNames)+1)
for i, name := range c.DNSNames {
hostname := toHostname(name)
audiences[i] = "https://" + hostname + path
}
// For backward compatibility
audiences[len(c.DNSNames)] = path
return audiences
}
func toHostname(name string) string {
// ensure an IPv6 address is represented with square brackets when used as hostname
if ip := net.ParseIP(name); ip != nil && ip.To4() == nil {

@ -2,6 +2,7 @@ package config
import (
"fmt"
"reflect"
"testing"
"github.com/pkg/errors"
@ -317,3 +318,38 @@ func Test_toHostname(t *testing.T) {
})
}
}
func TestConfig_Audience(t *testing.T) {
type fields struct {
DNSNames []string
}
type args struct {
path string
}
tests := []struct {
name string
fields fields
args args
want []string
}{
{"ok", fields{[]string{
"ca", "ca.example.com", "127.0.0.1", "::1",
}}, args{"/path"}, []string{
"https://ca/path",
"https://ca.example.com/path",
"https://127.0.0.1/path",
"https://[::1]/path",
"/path",
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Config{
DNSNames: tt.fields.DNSNames,
}
if got := c.Audience(tt.args.path); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Config.Audience() = %v, want %v", got, tt.want)
}
})
}
}

@ -15,15 +15,19 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/db"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/tlsutil"
"go.step.sm/crypto/x509util"
"go.step.sm/linkedca"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
)
const uuidPattern = "^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$"
@ -34,6 +38,9 @@ type linkedCaClient struct {
authorityID string
}
// interface guard
var _ admin.DB = (*linkedCaClient)(nil)
type linkedCAClaims struct {
jose.Claims
SANs []string `json:"sans"`
@ -115,6 +122,13 @@ func newLinkedCAClient(token string) (*linkedCaClient, error) {
}, nil
}
// IsLinkedCA is a sentinel function that can be used to
// check if a linkedCaClient is the underlying type of an
// admin.DB interface.
func (c *linkedCaClient) IsLinkedCA() bool {
return true
}
func (c *linkedCaClient) Run() {
c.renewer.Run()
}
@ -151,13 +165,21 @@ func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linked
}
func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
resp, err := c.GetConfiguration(ctx)
if err != nil {
return nil, err
}
return resp.Provisioners, nil
}
func (c *linkedCaClient) GetConfiguration(ctx context.Context) (*linkedca.ConfigurationResponse, error) {
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
AuthorityId: c.authorityID,
})
if err != nil {
return nil, errors.Wrap(err, "error getting provisioners")
return nil, errors.Wrap(err, "error getting configuration")
}
return resp.Provisioners, nil
return resp, nil
}
func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
@ -204,11 +226,9 @@ func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Adm
}
func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
AuthorityId: c.authorityID,
})
resp, err := c.GetConfiguration(ctx)
if err != nil {
return nil, errors.Wrap(err, "error getting admins")
return nil, err
}
return resp.Admins, nil
}
@ -228,12 +248,35 @@ func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error {
return errors.Wrap(err, "error deleting admin")
}
func (c *linkedCaClient) StoreCertificateChain(fullchain ...*x509.Certificate) error {
func (c *linkedCaClient) GetCertificateData(serial string) (*db.CertificateData, error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
resp, err := c.client.GetCertificate(ctx, &linkedca.GetCertificateRequest{
Serial: serial,
})
if err != nil {
return nil, err
}
var pd *db.ProvisionerData
if p := resp.Provisioner; p != nil {
pd = &db.ProvisionerData{
ID: p.Id, Name: p.Name, Type: p.Type.String(),
}
}
return &db.CertificateData{
Provisioner: pd,
}, nil
}
func (c *linkedCaClient) StoreCertificateChain(p provisioner.Interface, fullchain ...*x509.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
PemCertificate: serializeCertificateChain(fullchain[0]),
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
Provisioner: createProvisionerIdentity(p),
})
return errors.Wrap(err, "error posting certificate")
}
@ -246,18 +289,30 @@ func (c *linkedCaClient) StoreRenewedCertificate(parent *x509.Certificate, fullc
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
PemParentCertificate: serializeCertificateChain(parent),
})
return errors.Wrap(err, "error posting certificate")
return errors.Wrap(err, "error posting renewed certificate")
}
func (c *linkedCaClient) StoreSSHCertificate(crt *ssh.Certificate) error {
func (c *linkedCaClient) StoreSSHCertificate(p provisioner.Interface, crt *ssh.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{
Certificate: string(ssh.MarshalAuthorizedKey(crt)),
Provisioner: createProvisionerIdentity(p),
})
return errors.Wrap(err, "error posting ssh certificate")
}
func (c *linkedCaClient) StoreRenewedSSHCertificate(p provisioner.Interface, parent, crt *ssh.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{
Certificate: string(ssh.MarshalAuthorizedKey(crt)),
ParentCertificate: string(ssh.MarshalAuthorizedKey(parent)),
Provisioner: createProvisionerIdentity(p),
})
return errors.Wrap(err, "error posting renewed ssh certificate")
}
func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
@ -310,6 +365,33 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
}
func (c *linkedCaClient) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
return errors.New("not implemented yet")
}
func (c *linkedCaClient) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
return nil, errors.New("not implemented yet")
}
func (c *linkedCaClient) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
return errors.New("not implemented yet")
}
func (c *linkedCaClient) DeleteAuthorityPolicy(ctx context.Context) error {
return errors.New("not implemented yet")
}
func createProvisionerIdentity(p provisioner.Interface) *linkedca.ProvisionerIdentity {
if p == nil {
return nil
}
return &linkedca.ProvisionerIdentity{
Id: p.GetID(),
Type: linkedca.Provisioner_Type(p.GetType()),
Name: p.GetName(),
}
}
func serializeCertificate(crt *x509.Certificate) string {
if crt == nil {
return ""

@ -266,6 +266,16 @@ func WithAdminDB(d admin.DB) Option {
}
}
// WithProvisioners is an option to set the provisioner collection.
//
// Deprecated: provisioner collections will likely change
func WithProvisioners(ps *provisioner.Collection) Option {
return func(a *Authority) error {
a.provisioners = ps
return nil
}
}
// WithLinkedCAToken is an option to set the authentication token used to enable
// linked ca.
func WithLinkedCAToken(token string) Option {
@ -284,6 +294,15 @@ func WithX509Enforcers(ces ...provisioner.CertificateEnforcer) Option {
}
}
// WithSkipInit is an option that allows the constructor to skip initializtion
// of the authority.
func WithSkipInit() Option {
return func(a *Authority) error {
a.skipInit = true
return nil
}
}
func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) {
var block *pem.Block
var certs []*x509.Certificate

@ -0,0 +1,265 @@
package authority
import (
"context"
"errors"
"fmt"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/admin"
authPolicy "github.com/smallstep/certificates/authority/policy"
policy "github.com/smallstep/certificates/policy"
)
type policyErrorType int
const (
AdminLockOut policyErrorType = iota + 1
StoreFailure
ReloadFailure
ConfigurationFailure
EvaluationFailure
InternalFailure
)
type PolicyError struct {
Typ policyErrorType
Err error
}
func (p *PolicyError) Error() string {
return p.Err.Error()
}
func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
p, err := a.adminDB.GetAuthorityPolicy(ctx)
if err != nil {
return nil, &PolicyError{
Typ: InternalFailure,
Err: err,
}
}
return p, nil
}
func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil {
return nil, err
}
if err := a.adminDB.CreateAuthorityPolicy(ctx, p); err != nil {
return nil, &PolicyError{
Typ: StoreFailure,
Err: err,
}
}
if err := a.reloadPolicyEngines(ctx); err != nil {
return nil, &PolicyError{
Typ: ReloadFailure,
Err: fmt.Errorf("error reloading policy engines when creating authority policy: %w", err),
}
}
return p, nil
}
func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil {
return nil, err
}
if err := a.adminDB.UpdateAuthorityPolicy(ctx, p); err != nil {
return nil, &PolicyError{
Typ: StoreFailure,
Err: err,
}
}
if err := a.reloadPolicyEngines(ctx); err != nil {
return nil, &PolicyError{
Typ: ReloadFailure,
Err: fmt.Errorf("error reloading policy engines when updating authority policy: %w", err),
}
}
return p, nil
}
func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
if err := a.adminDB.DeleteAuthorityPolicy(ctx); err != nil {
return &PolicyError{
Typ: StoreFailure,
Err: err,
}
}
if err := a.reloadPolicyEngines(ctx); err != nil {
return &PolicyError{
Typ: ReloadFailure,
Err: fmt.Errorf("error reloading policy engines when deleting authority policy: %w", err),
}
}
return nil
}
func (a *Authority) checkAuthorityPolicy(ctx context.Context, currentAdmin *linkedca.Admin, p *linkedca.Policy) error {
// no policy and thus nothing to evaluate; return early
if p == nil {
return nil
}
// get all current admins from the database
allAdmins, err := a.adminDB.GetAdmins(ctx)
if err != nil {
return &PolicyError{
Typ: InternalFailure,
Err: fmt.Errorf("error retrieving admins: %w", err),
}
}
return a.checkPolicy(ctx, currentAdmin, allAdmins, p)
}
func (a *Authority) checkProvisionerPolicy(ctx context.Context, provName string, p *linkedca.Policy) error {
// no policy and thus nothing to evaluate; return early
if p == nil {
return nil
}
// get all admins for the provisioner; ignoring case in which they're not found
allProvisionerAdmins, _ := a.admins.LoadByProvisioner(provName)
// check the policy; pass in nil as the current admin, as all admins for the
// provisioner will be checked by looping through allProvisionerAdmins. Also,
// the current admin may be a super admin not belonging to the provisioner, so
// can't be blocked, but is not required to be in the policy, either.
return a.checkPolicy(ctx, nil, allProvisionerAdmins, p)
}
// checkPolicy checks if a new or updated policy configuration results in the user
// locking themselves or other admins out of the CA.
func (a *Authority) checkPolicy(ctx context.Context, currentAdmin *linkedca.Admin, otherAdmins []*linkedca.Admin, p *linkedca.Policy) error {
// convert the policy; return early if nil
policyOptions := authPolicy.LinkedToCertificates(p)
if policyOptions == nil {
return nil
}
engine, err := authPolicy.NewX509PolicyEngine(policyOptions.GetX509Options())
if err != nil {
return &PolicyError{
Typ: ConfigurationFailure,
Err: err,
}
}
// when an empty X.509 policy is provided, the resulting engine is nil
// and there's no policy to evaluate.
if engine == nil {
return nil
}
// TODO(hs): Provide option to force the policy, even when the admin subject would be locked out?
// check if the admin user that instructed the authority policy to be
// created or updated, would still be allowed when the provided policy
// would be applied. This case is skipped when current admin is nil, which
// is the case when a provisioner policy is checked.
if currentAdmin != nil {
sans := []string{currentAdmin.GetSubject()}
if err := isAllowed(engine, sans); err != nil {
return err
}
}
// loop through admins to verify that none of them would be
// locked out when the new policy were to be applied. Returns
// an error with a message that includes the admin subject that
// would be locked out.
for _, adm := range otherAdmins {
sans := []string{adm.GetSubject()}
if err := isAllowed(engine, sans); err != nil {
return err
}
}
// TODO(hs): mask the error message for non-super admins?
return nil
}
// reloadPolicyEngines reloads x509 and SSH policy engines using
// configuration stored in the DB or from the configuration file.
func (a *Authority) reloadPolicyEngines(ctx context.Context) error {
var (
err error
policyOptions *authPolicy.Options
)
if a.config.AuthorityConfig.EnableAdmin {
// temporarily disable policy loading when LinkedCA is in use
if _, ok := a.adminDB.(*linkedCaClient); ok {
return nil
}
linkedPolicy, err := a.adminDB.GetAuthorityPolicy(ctx)
if err != nil {
var ae *admin.Error
if isAdminError := errors.As(err, &ae); (isAdminError && ae.Type != admin.ErrorNotFoundType.String()) || !isAdminError {
return fmt.Errorf("error getting policy to (re)load policy engines: %w", err)
}
}
policyOptions = authPolicy.LinkedToCertificates(linkedPolicy)
} else {
policyOptions = a.config.AuthorityConfig.Policy
}
engine, err := authPolicy.New(policyOptions)
if err != nil {
return err
}
// only update the policy engine when no error was returned
a.policyEngine = engine
return nil
}
func isAllowed(engine authPolicy.X509Policy, sans []string) error {
if err := engine.AreSANsAllowed(sans); err != nil {
var policyErr *policy.NamePolicyError
isNamePolicyError := errors.As(err, &policyErr)
if isNamePolicyError && policyErr.Reason == policy.NotAllowed {
return &PolicyError{
Typ: AdminLockOut,
Err: fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans),
}
}
return &PolicyError{
Typ: EvaluationFailure,
Err: err,
}
}
return nil
}

@ -0,0 +1,114 @@
package policy
import (
"crypto/x509"
"errors"
"fmt"
"golang.org/x/crypto/ssh"
)
// Engine is a container for multiple policies.
type Engine struct {
x509Policy X509Policy
sshUserPolicy UserPolicy
sshHostPolicy HostPolicy
}
// New returns a new Engine using Options.
func New(options *Options) (*Engine, error) {
// if no options provided, return early
if options == nil {
return nil, nil
}
var (
x509Policy X509Policy
sshHostPolicy HostPolicy
sshUserPolicy UserPolicy
err error
)
// initialize the x509 allow/deny policy engine
if x509Policy, err = NewX509PolicyEngine(options.GetX509Options()); err != nil {
return nil, err
}
// initialize the SSH allow/deny policy engine for host certificates
if sshHostPolicy, err = NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil {
return nil, err
}
// initialize the SSH allow/deny policy engine for user certificates
if sshUserPolicy, err = NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil {
return nil, err
}
return &Engine{
x509Policy: x509Policy,
sshHostPolicy: sshHostPolicy,
sshUserPolicy: sshUserPolicy,
}, nil
}
// IsX509CertificateAllowed evaluates an X.509 certificate against
// the X.509 policy (if available) and returns an error if one of the
// names in the certificate is not allowed.
func (e *Engine) IsX509CertificateAllowed(cert *x509.Certificate) error {
// return early if there's no policy to evaluate
if e == nil || e.x509Policy == nil {
return nil
}
// return result of X.509 policy evaluation
return e.x509Policy.IsX509CertificateAllowed(cert)
}
// AreSANsAllowed evaluates the slice of SANs against the X.509 policy
// (if available) and returns an error if one of the SANs is not allowed.
func (e *Engine) AreSANsAllowed(sans []string) error {
// return early if there's no policy to evaluate
if e == nil || e.x509Policy == nil {
return nil
}
// return result of X.509 policy evaluation
return e.x509Policy.AreSANsAllowed(sans)
}
// IsSSHCertificateAllowed evaluates an SSH certificate against the
// user or host policy (if configured) and returns an error if one of the
// principals in the certificate is not allowed.
func (e *Engine) IsSSHCertificateAllowed(cert *ssh.Certificate) error {
// return early if there's no policy to evaluate
if e == nil || (e.sshHostPolicy == nil && e.sshUserPolicy == nil) {
return nil
}
switch cert.CertType {
case ssh.HostCert:
// when no host policy engine is configured, but a user policy engine is
// configured, the host certificate is denied.
if e.sshHostPolicy == nil && e.sshUserPolicy != nil {
return errors.New("authority not allowed to sign ssh host certificates")
}
// return result of SSH host policy evaluation
return e.sshHostPolicy.IsSSHCertificateAllowed(cert)
case ssh.UserCert:
// when no user policy engine is configured, but a host policy engine is
// configured, the user certificate is denied.
if e.sshUserPolicy == nil && e.sshHostPolicy != nil {
return errors.New("authority not allowed to sign ssh user certificates")
}
// return result of SSH user policy evaluation
return e.sshUserPolicy.IsSSHCertificateAllowed(cert)
default:
return fmt.Errorf("unexpected ssh certificate type %q", cert.CertType)
}
}

@ -0,0 +1,194 @@
package policy
// Options is a container for authority level x509 and SSH
// policy configuration.
type Options struct {
X509 *X509PolicyOptions `json:"x509,omitempty"`
SSH *SSHPolicyOptions `json:"ssh,omitempty"`
}
// GetX509Options returns the x509 authority level policy
// configuration
func (o *Options) GetX509Options() *X509PolicyOptions {
if o == nil {
return nil
}
return o.X509
}
// GetSSHOptions returns the SSH authority level policy
// configuration
func (o *Options) GetSSHOptions() *SSHPolicyOptions {
if o == nil {
return nil
}
return o.SSH
}
// X509PolicyOptionsInterface is an interface for providers
// of x509 allowed and denied names.
type X509PolicyOptionsInterface interface {
GetAllowedNameOptions() *X509NameOptions
GetDeniedNameOptions() *X509NameOptions
AreWildcardNamesAllowed() bool
}
// X509PolicyOptions is a container for x509 allowed and denied
// names.
type X509PolicyOptions struct {
// AllowedNames contains the x509 allowed names
AllowedNames *X509NameOptions `json:"allow,omitempty"`
// DeniedNames contains the x509 denied names
DeniedNames *X509NameOptions `json:"deny,omitempty"`
// AllowWildcardNames indicates if literal wildcard names
// like *.example.com are allowed. Defaults to false.
AllowWildcardNames bool `json:"allowWildcardNames,omitempty"`
}
// X509NameOptions models the X509 name policy configuration.
type X509NameOptions struct {
CommonNames []string `json:"cn,omitempty"`
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
EmailAddresses []string `json:"email,omitempty"`
URIDomains []string `json:"uri,omitempty"`
}
// HasNames checks if the AllowedNameOptions has one or more
// names configured.
func (o *X509NameOptions) HasNames() bool {
return len(o.CommonNames) > 0 ||
len(o.DNSDomains) > 0 ||
len(o.IPRanges) > 0 ||
len(o.EmailAddresses) > 0 ||
len(o.URIDomains) > 0
}
// GetAllowedNameOptions returns x509 allowed name policy configuration
func (o *X509PolicyOptions) GetAllowedNameOptions() *X509NameOptions {
if o == nil {
return nil
}
return o.AllowedNames
}
// GetDeniedNameOptions returns the x509 denied name policy configuration
func (o *X509PolicyOptions) GetDeniedNameOptions() *X509NameOptions {
if o == nil {
return nil
}
return o.DeniedNames
}
// AreWildcardNamesAllowed returns whether the authority allows
// literal wildcard names to be signed.
func (o *X509PolicyOptions) AreWildcardNamesAllowed() bool {
if o == nil {
return true
}
return o.AllowWildcardNames
}
// SSHPolicyOptionsInterface is an interface for providers of
// SSH user and host name policy configuration.
type SSHPolicyOptionsInterface interface {
GetAllowedUserNameOptions() *SSHNameOptions
GetDeniedUserNameOptions() *SSHNameOptions
GetAllowedHostNameOptions() *SSHNameOptions
GetDeniedHostNameOptions() *SSHNameOptions
}
// SSHPolicyOptions is a container for SSH user and host policy
// configuration
type SSHPolicyOptions struct {
// User contains SSH user certificate options.
User *SSHUserCertificateOptions `json:"user,omitempty"`
// Host contains SSH host certificate options.
Host *SSHHostCertificateOptions `json:"host,omitempty"`
}
// GetAllowedUserNameOptions returns the SSH allowed user name policy
// configuration.
func (o *SSHPolicyOptions) GetAllowedUserNameOptions() *SSHNameOptions {
if o == nil || o.User == nil {
return nil
}
return o.User.AllowedNames
}
// GetDeniedUserNameOptions returns the SSH denied user name policy
// configuration.
func (o *SSHPolicyOptions) GetDeniedUserNameOptions() *SSHNameOptions {
if o == nil || o.User == nil {
return nil
}
return o.User.DeniedNames
}
// GetAllowedHostNameOptions returns the SSH allowed host name policy
// configuration.
func (o *SSHPolicyOptions) GetAllowedHostNameOptions() *SSHNameOptions {
if o == nil || o.Host == nil {
return nil
}
return o.Host.AllowedNames
}
// GetDeniedHostNameOptions returns the SSH denied host name policy
// configuration.
func (o *SSHPolicyOptions) GetDeniedHostNameOptions() *SSHNameOptions {
if o == nil || o.Host == nil {
return nil
}
return o.Host.DeniedNames
}
// SSHUserCertificateOptions is a collection of SSH user certificate options.
type SSHUserCertificateOptions struct {
// AllowedNames contains the names the provisioner is authorized to sign
AllowedNames *SSHNameOptions `json:"allow,omitempty"`
// DeniedNames contains the names the provisioner is not authorized to sign
DeniedNames *SSHNameOptions `json:"deny,omitempty"`
}
// SSHHostCertificateOptions is a collection of SSH host certificate options.
// It's an alias of SSHUserCertificateOptions, as the options are the same
// for both types of certificates.
type SSHHostCertificateOptions SSHUserCertificateOptions
// SSHNameOptions models the SSH name policy configuration.
type SSHNameOptions struct {
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
EmailAddresses []string `json:"email,omitempty"`
Principals []string `json:"principal,omitempty"`
}
// GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the
// names that a provisioner is authorized to sign SSH certificates for.
func (o *SSHUserCertificateOptions) GetAllowedNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
return o.AllowedNames
}
// GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the
// names that a provisioner is NOT authorized to sign SSH certificates for.
func (o *SSHUserCertificateOptions) GetDeniedNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
return o.DeniedNames
}
// HasNames checks if the SSHNameOptions has one or more
// names configured.
func (o *SSHNameOptions) HasNames() bool {
return len(o.DNSDomains) > 0 ||
len(o.IPRanges) > 0 ||
len(o.EmailAddresses) > 0 ||
len(o.Principals) > 0
}

@ -0,0 +1,45 @@
package policy
import (
"testing"
)
func TestX509PolicyOptions_IsWildcardLiteralAllowed(t *testing.T) {
tests := []struct {
name string
options *X509PolicyOptions
want bool
}{
{
name: "nil-options",
options: nil,
want: true,
},
{
name: "not-set",
options: &X509PolicyOptions{},
want: false,
},
{
name: "set-true",
options: &X509PolicyOptions{
AllowWildcardNames: true,
},
want: true,
},
{
name: "set-false",
options: &X509PolicyOptions{
AllowWildcardNames: false,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.options.AreWildcardNamesAllowed(); got != tt.want {
t.Errorf("X509PolicyOptions.IsWildcardLiteralAllowed() = %v, want %v", got, tt.want)
}
})
}
}

@ -0,0 +1,256 @@
package policy
import (
"fmt"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/policy"
)
// X509Policy is an alias for policy.X509NamePolicyEngine
type X509Policy policy.X509NamePolicyEngine
// UserPolicy is an alias for policy.SSHNamePolicyEngine
type UserPolicy policy.SSHNamePolicyEngine
// HostPolicy is an alias for policy.SSHNamePolicyEngine
type HostPolicy policy.SSHNamePolicyEngine
// NewX509PolicyEngine creates a new x509 name policy engine
func NewX509PolicyEngine(policyOptions X509PolicyOptionsInterface) (X509Policy, error) {
// return early if no policy engine options to configure
if policyOptions == nil {
return nil, nil
}
options := []policy.NamePolicyOption{}
allowed := policyOptions.GetAllowedNameOptions()
if allowed != nil && allowed.HasNames() {
options = append(options,
policy.WithPermittedCommonNames(allowed.CommonNames...),
policy.WithPermittedDNSDomains(allowed.DNSDomains...),
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...),
policy.WithPermittedEmailAddresses(allowed.EmailAddresses...),
policy.WithPermittedURIDomains(allowed.URIDomains...),
)
}
denied := policyOptions.GetDeniedNameOptions()
if denied != nil && denied.HasNames() {
options = append(options,
policy.WithExcludedCommonNames(denied.CommonNames...),
policy.WithExcludedDNSDomains(denied.DNSDomains...),
policy.WithExcludedIPsOrCIDRs(denied.IPRanges...),
policy.WithExcludedEmailAddresses(denied.EmailAddresses...),
policy.WithExcludedURIDomains(denied.URIDomains...),
)
}
// ensure no policy engine is returned when no name options were provided
if len(options) == 0 {
return nil, nil
}
// check if configuration specifies that wildcard names are allowed
if policyOptions.AreWildcardNamesAllowed() {
options = append(options, policy.WithAllowLiteralWildcardNames())
}
// enable subject common name verification by default
options = append(options, policy.WithSubjectCommonNameVerification())
return policy.New(options...)
}
type sshPolicyEngineType string
const (
UserPolicyEngineType sshPolicyEngineType = "user"
HostPolicyEngineType sshPolicyEngineType = "host"
)
// newSSHUserPolicyEngine creates a new SSH user certificate policy engine
func NewSSHUserPolicyEngine(policyOptions SSHPolicyOptionsInterface) (UserPolicy, error) {
policyEngine, err := newSSHPolicyEngine(policyOptions, UserPolicyEngineType)
if err != nil {
return nil, err
}
return policyEngine, nil
}
// newSSHHostPolicyEngine create a new SSH host certificate policy engine
func NewSSHHostPolicyEngine(policyOptions SSHPolicyOptionsInterface) (HostPolicy, error) {
policyEngine, err := newSSHPolicyEngine(policyOptions, HostPolicyEngineType)
if err != nil {
return nil, err
}
return policyEngine, nil
}
// newSSHPolicyEngine creates a new SSH name policy engine
func newSSHPolicyEngine(policyOptions SSHPolicyOptionsInterface, typ sshPolicyEngineType) (policy.SSHNamePolicyEngine, error) {
// return early if no policy engine options to configure
if policyOptions == nil {
return nil, nil
}
var (
allowed *SSHNameOptions
denied *SSHNameOptions
)
switch typ {
case UserPolicyEngineType:
allowed = policyOptions.GetAllowedUserNameOptions()
denied = policyOptions.GetDeniedUserNameOptions()
case HostPolicyEngineType:
allowed = policyOptions.GetAllowedHostNameOptions()
denied = policyOptions.GetDeniedHostNameOptions()
default:
return nil, fmt.Errorf("unknown SSH policy engine type %s provided", typ)
}
options := []policy.NamePolicyOption{}
if allowed != nil && allowed.HasNames() {
options = append(options,
policy.WithPermittedDNSDomains(allowed.DNSDomains...),
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...),
policy.WithPermittedEmailAddresses(allowed.EmailAddresses...),
policy.WithPermittedPrincipals(allowed.Principals...),
)
}
if denied != nil && denied.HasNames() {
options = append(options,
policy.WithExcludedDNSDomains(denied.DNSDomains...),
policy.WithExcludedIPsOrCIDRs(denied.IPRanges...),
policy.WithExcludedEmailAddresses(denied.EmailAddresses...),
policy.WithExcludedPrincipals(denied.Principals...),
)
}
// ensure no policy engine is returned when no name options were provided
if len(options) == 0 {
return nil, nil
}
return policy.New(options...)
}
func LinkedToCertificates(p *linkedca.Policy) *Options {
// return early
if p == nil {
return nil
}
// return early if x509 nor SSH is set
if p.GetX509() == nil && p.GetSsh() == nil {
return nil
}
opts := &Options{}
// fill x509 policy configuration
if x509 := p.GetX509(); x509 != nil {
opts.X509 = &X509PolicyOptions{}
if allow := x509.GetAllow(); allow != nil {
opts.X509.AllowedNames = &X509NameOptions{}
if allow.Dns != nil {
opts.X509.AllowedNames.DNSDomains = allow.Dns
}
if allow.Ips != nil {
opts.X509.AllowedNames.IPRanges = allow.Ips
}
if allow.Emails != nil {
opts.X509.AllowedNames.EmailAddresses = allow.Emails
}
if allow.Uris != nil {
opts.X509.AllowedNames.URIDomains = allow.Uris
}
if allow.CommonNames != nil {
opts.X509.AllowedNames.CommonNames = allow.CommonNames
}
}
if deny := x509.GetDeny(); deny != nil {
opts.X509.DeniedNames = &X509NameOptions{}
if deny.Dns != nil {
opts.X509.DeniedNames.DNSDomains = deny.Dns
}
if deny.Ips != nil {
opts.X509.DeniedNames.IPRanges = deny.Ips
}
if deny.Emails != nil {
opts.X509.DeniedNames.EmailAddresses = deny.Emails
}
if deny.Uris != nil {
opts.X509.DeniedNames.URIDomains = deny.Uris
}
if deny.CommonNames != nil {
opts.X509.DeniedNames.CommonNames = deny.CommonNames
}
}
opts.X509.AllowWildcardNames = x509.GetAllowWildcardNames()
}
// fill ssh policy configuration
if ssh := p.GetSsh(); ssh != nil {
opts.SSH = &SSHPolicyOptions{}
if host := ssh.GetHost(); host != nil {
opts.SSH.Host = &SSHHostCertificateOptions{}
if allow := host.GetAllow(); allow != nil {
opts.SSH.Host.AllowedNames = &SSHNameOptions{}
if allow.Dns != nil {
opts.SSH.Host.AllowedNames.DNSDomains = allow.Dns
}
if allow.Ips != nil {
opts.SSH.Host.AllowedNames.IPRanges = allow.Ips
}
if allow.Principals != nil {
opts.SSH.Host.AllowedNames.Principals = allow.Principals
}
}
if deny := host.GetDeny(); deny != nil {
opts.SSH.Host.DeniedNames = &SSHNameOptions{}
if deny.Dns != nil {
opts.SSH.Host.DeniedNames.DNSDomains = deny.Dns
}
if deny.Ips != nil {
opts.SSH.Host.DeniedNames.IPRanges = deny.Ips
}
if deny.Principals != nil {
opts.SSH.Host.DeniedNames.Principals = deny.Principals
}
}
}
if user := ssh.GetUser(); user != nil {
opts.SSH.User = &SSHUserCertificateOptions{}
if allow := user.GetAllow(); allow != nil {
opts.SSH.User.AllowedNames = &SSHNameOptions{}
if allow.Emails != nil {
opts.SSH.User.AllowedNames.EmailAddresses = allow.Emails
}
if allow.Principals != nil {
opts.SSH.User.AllowedNames.Principals = allow.Principals
}
}
if deny := user.GetDeny(); deny != nil {
opts.SSH.User.DeniedNames = &SSHNameOptions{}
if deny.Emails != nil {
opts.SSH.User.DeniedNames.EmailAddresses = deny.Emails
}
if deny.Principals != nil {
opts.SSH.User.DeniedNames.Principals = deny.Principals
}
}
}
}
return opts
}

@ -0,0 +1,155 @@
package policy
import (
"testing"
"github.com/google/go-cmp/cmp"
"go.step.sm/linkedca"
)
func TestPolicyToCertificates(t *testing.T) {
type args struct {
policy *linkedca.Policy
}
tests := []struct {
name string
args args
want *Options
}{
{
name: "nil",
args: args{
policy: nil,
},
want: nil,
},
{
name: "no-policy",
args: args{
&linkedca.Policy{},
},
want: nil,
},
{
name: "partial-policy",
args: args{
&linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
},
AllowWildcardNames: false,
},
},
},
want: &Options{
X509: &X509PolicyOptions{
AllowedNames: &X509NameOptions{
DNSDomains: []string{"*.local"},
},
AllowWildcardNames: false,
},
},
},
{
name: "full-policy",
args: args{
&linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"step"},
Ips: []string{"127.0.0.1/24"},
Emails: []string{"*.example.com"},
Uris: []string{"https://*.local"},
CommonNames: []string{"some name"},
},
Deny: &linkedca.X509Names{
Dns: []string{"bad"},
Ips: []string{"127.0.0.30"},
Emails: []string{"badhost.example.com"},
Uris: []string{"https://badhost.local"},
CommonNames: []string{"another name"},
},
AllowWildcardNames: true,
},
Ssh: &linkedca.SSHPolicy{
Host: &linkedca.SSHHostPolicy{
Allow: &linkedca.SSHHostNames{
Dns: []string{"*.localhost"},
Ips: []string{"127.0.0.1/24"},
Principals: []string{"user"},
},
Deny: &linkedca.SSHHostNames{
Dns: []string{"badhost.localhost"},
Ips: []string{"127.0.0.40"},
Principals: []string{"root"},
},
},
User: &linkedca.SSHUserPolicy{
Allow: &linkedca.SSHUserNames{
Emails: []string{"@work"},
Principals: []string{"user"},
},
Deny: &linkedca.SSHUserNames{
Emails: []string{"root@work"},
Principals: []string{"root"},
},
},
},
},
},
want: &Options{
X509: &X509PolicyOptions{
AllowedNames: &X509NameOptions{
DNSDomains: []string{"step"},
IPRanges: []string{"127.0.0.1/24"},
EmailAddresses: []string{"*.example.com"},
URIDomains: []string{"https://*.local"},
CommonNames: []string{"some name"},
},
DeniedNames: &X509NameOptions{
DNSDomains: []string{"bad"},
IPRanges: []string{"127.0.0.30"},
EmailAddresses: []string{"badhost.example.com"},
URIDomains: []string{"https://badhost.local"},
CommonNames: []string{"another name"},
},
AllowWildcardNames: true,
},
SSH: &SSHPolicyOptions{
Host: &SSHHostCertificateOptions{
AllowedNames: &SSHNameOptions{
DNSDomains: []string{"*.localhost"},
IPRanges: []string{"127.0.0.1/24"},
Principals: []string{"user"},
},
DeniedNames: &SSHNameOptions{
DNSDomains: []string{"badhost.localhost"},
IPRanges: []string{"127.0.0.40"},
Principals: []string{"root"},
},
},
User: &SSHUserCertificateOptions{
AllowedNames: &SSHNameOptions{
EmailAddresses: []string{"@work"},
Principals: []string{"user"},
},
DeniedNames: &SSHNameOptions{
EmailAddresses: []string{"root@work"},
Principals: []string{"root"},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := LinkedToCertificates(tt.args.policy)
if !cmp.Equal(tt.want, got) {
t.Errorf("policyToCertificates() diff=\n%s", cmp.Diff(tt.want, got))
}
})
}
}

File diff suppressed because it is too large Load Diff

@ -3,6 +3,8 @@ package provisioner
import (
"context"
"crypto/x509"
"fmt"
"net"
"time"
"github.com/pkg/errors"
@ -23,7 +25,8 @@ type ACME struct {
RequireEAB bool `json:"requireEAB,omitempty"`
Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"`
ctl *Controller
ctl *Controller
}
// GetID returns the provisioner unique identifier.
@ -71,7 +74,7 @@ func (p *ACME) DefaultTLSCertDuration() time.Duration {
return p.ctl.Claimer.DefaultTLSCertDuration()
}
// Init initializes and validates the fields of a JWK type.
// Init initializes and validates the fields of an ACME type.
func (p *ACME) Init(config Config) (err error) {
switch {
case p.Type == "":
@ -80,15 +83,57 @@ func (p *ACME) Init(config Config) (err error) {
return errors.New("provisioner name cannot be empty")
}
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
// ACMEIdentifierType encodes ACME Identifier types
type ACMEIdentifierType string
const (
// IP is the ACME ip identifier type
IP ACMEIdentifierType = "ip"
// DNS is the ACME dns identifier type
DNS ACMEIdentifierType = "dns"
)
// ACMEIdentifier encodes ACME Order Identifiers
type ACMEIdentifier struct {
Type ACMEIdentifierType
Value string
}
// AuthorizeOrderIdentifier verifies the provisioner is allowed to issue a
// certificate for an ACME Order Identifier.
func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier ACMEIdentifier) error {
x509Policy := p.ctl.getPolicy().getX509()
// identifier is allowed if no policy is configured
if x509Policy == nil {
return nil
}
// assuming only valid identifiers (IP or DNS) are provided
var err error
switch identifier.Type {
case IP:
err = x509Policy.IsIPAllowed(net.ParseIP(identifier.Value))
case DNS:
err = x509Policy.IsDNSAllowed(identifier.Value)
default:
err = fmt.Errorf("invalid ACME identifier type '%s' provided", identifier.Type)
}
return err
}
// AuthorizeSign does not do any validation, because all validation is handled
// in the ACME protocol. This method returns a list of modifiers / constraints
// on the resulting certificate.
func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{
opts := []SignOption{
p,
// modifiers / withOptions
newProvisionerExtensionOption(TypeACME, p.Name, ""),
newForceCNOption(p.ForceCN),
@ -96,7 +141,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// validators
defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
}, nil
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
}
return opts, nil
}
// AuthorizeRevoke is called just before the certificate is to be revoked by

@ -176,9 +176,10 @@ func TestACME_AuthorizeSign(t *testing.T) {
}
} else {
if assert.Nil(t, tc.err) && assert.NotNil(t, opts) {
assert.Len(t, 5, opts)
assert.Equals(t, 7, len(opts)) // number of SignOptions returned
for _, o := range opts {
switch v := o.(type) {
case *ACME:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeACME)
assert.Equals(t, v.Name, tc.p.GetName())
@ -192,6 +193,8 @@ func TestACME_AuthorizeSign(t *testing.T) {
case *validityValidator:
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}

@ -17,10 +17,12 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
)
// awsIssuer is the string used as issuer in the generated tokens.
@ -49,22 +51,27 @@ const awsMetadataTokenTTLHeader = "X-aws-ec2-metadata-token-ttl-seconds"
// signature.
//
// The first certificate is used in:
// ap-northeast-2, ap-south-1, ap-southeast-1, ap-southeast-2
// eu-central-1, eu-north-1, eu-west-1, eu-west-2, eu-west-3
// us-east-1, us-east-2, us-west-1, us-west-2
// ca-central-1, sa-east-1
//
// ap-northeast-2, ap-south-1, ap-southeast-1, ap-southeast-2
// eu-central-1, eu-north-1, eu-west-1, eu-west-2, eu-west-3
// us-east-1, us-east-2, us-west-1, us-west-2
// ca-central-1, sa-east-1
//
// The second certificate is used in:
// eu-south-1
//
// eu-south-1
//
// The third certificate is used in:
// ap-east-1
//
// ap-east-1
//
// The fourth certificate is used in:
// af-south-1
//
// af-south-1
//
// The fifth certificate is used in:
// me-south-1
//
// me-south-1
const awsCertificate = `-----BEGIN CERTIFICATE-----
MIIDIjCCAougAwIBAgIJAKnL4UEDMN/FMA0GCSqGSIb3DQEBBQUAMGoxCzAJBgNV
BAYTAlVTMRMwEQYDVQQIEwpXYXNoaW5ndG9uMRAwDgYDVQQHEwdTZWF0dGxlMRgw
@ -421,7 +428,7 @@ func (p *AWS) Init(config Config) (err error) {
}
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
@ -467,6 +474,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
}
return append(so,
p,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID),
@ -475,6 +483,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
defaultPublicKeyValidator{},
commonNameValidator(payload.Claims.Subject),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
), nil
}
@ -542,7 +551,7 @@ func (p *AWS) readURL(url string) ([]byte, error) {
if err != nil {
return nil, err
}
return nil, fmt.Errorf("Request for metadata returned non-successful status code %d",
return nil, fmt.Errorf("request for metadata returned non-successful status code %d",
resp.StatusCode)
}
@ -575,7 +584,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode)
return nil, fmt.Errorf("request for API token returned non-successful status code %d", resp.StatusCode)
}
token, err := io.ReadAll(resp.Body)
if err != nil {
@ -743,6 +752,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
signOptions = append(signOptions, templateOptions)
return append(signOptions,
p,
// Validate user SignSSHOptions.
sshCertOptionsValidator(defaults),
// Set the validity bounds if not set.
@ -753,5 +763,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
), nil
}

@ -642,11 +642,11 @@ func TestAWS_AuthorizeSign(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{t1, "foo.local"}, 6, http.StatusOK, false},
{"ok", p2, args{t2, "instance-id"}, 10, http.StatusOK, false},
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 10, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 10, http.StatusOK, false},
{"ok", p1, args{t4, "instance-id"}, 6, http.StatusOK, false},
{"ok", p1, args{t1, "foo.local"}, 8, http.StatusOK, false},
{"ok", p2, args{t2, "instance-id"}, 12, http.StatusOK, false},
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 12, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 12, http.StatusOK, false},
{"ok", p1, args{t4, "instance-id"}, 8, http.StatusOK, false},
{"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true},
{"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true},
@ -673,9 +673,10 @@ func TestAWS_AuthorizeSign(t *testing.T) {
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
default:
assert.Len(t, tt.wantLen, got)
assert.Equals(t, tt.wantLen, len(got))
for _, o := range got {
switch v := o.(type) {
case *AWS:
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeAWS)
@ -698,6 +699,8 @@ func TestAWS_AuthorizeSign(t *testing.T) {
assert.Equals(t, v, nil)
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"})
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
@ -810,7 +813,6 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
} else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
if (err != nil) != tt.wantSignErr {
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
} else {
if tt.wantSignErr {

@ -13,10 +13,12 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
)
// azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens.
@ -219,7 +221,7 @@ func (p *Azure) Init(config Config) (err error) {
return
}
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
@ -352,6 +354,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
}
return append(so,
p,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
@ -359,6 +362,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// validators
defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
), nil
}
@ -414,6 +418,7 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
signOptions = append(signOptions, templateOptions)
return append(signOptions,
p,
// Validate user SignSSHOptions.
sshCertOptionsValidator(defaults),
// Set the validity bounds if not set.
@ -424,6 +429,8 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
), nil
}

@ -474,11 +474,11 @@ func TestAzure_AuthorizeSign(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{t1}, 5, http.StatusOK, false},
{"ok", p2, args{t2}, 10, http.StatusOK, false},
{"ok", p1, args{t11}, 5, http.StatusOK, false},
{"ok", p5, args{t5}, 5, http.StatusOK, false},
{"ok", p7, args{t7}, 5, http.StatusOK, false},
{"ok", p1, args{t1}, 7, http.StatusOK, false},
{"ok", p2, args{t2}, 12, http.StatusOK, false},
{"ok", p1, args{t11}, 7, http.StatusOK, false},
{"ok", p5, args{t5}, 7, http.StatusOK, false},
{"ok", p7, args{t7}, 7, http.StatusOK, false},
{"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true},
{"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true},
{"fail subscription", p6, args{t6}, 0, http.StatusUnauthorized, true},
@ -502,9 +502,10 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
default:
assert.Len(t, tt.wantLen, got)
assert.Equals(t, tt.wantLen, len(got))
for _, o := range got {
switch v := o.(type) {
case *Azure:
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeAzure)
@ -527,6 +528,8 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.Equals(t, v, nil)
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"virtualMachine"})
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}

@ -24,8 +24,8 @@ type Claims struct {
EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
// Renewal properties
DisableRenewal *bool `json:"disableRenewal,omitempty"`
AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"`
DisableRenewal *bool `json:"disableRenewal,omitempty"`
AllowRenewalAfterExpiry *bool `json:"allowRenewalAfterExpiry,omitempty"`
}
// Claimer is the type that controls claims. It provides an interface around the
@ -44,22 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
// Claims returns the merge of the inner and global claims.
func (c *Claimer) Claims() Claims {
disableRenewal := c.IsDisableRenewal()
allowRenewAfterExpiry := c.AllowRenewAfterExpiry()
allowRenewalAfterExpiry := c.AllowRenewalAfterExpiry()
enableSSHCA := c.IsSSHCAEnabled()
return Claims{
MinTLSDur: &Duration{c.MinTLSCertDuration()},
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
EnableSSHCA: &enableSSHCA,
DisableRenewal: &disableRenewal,
AllowRenewAfterExpiry: &allowRenewAfterExpiry,
MinTLSDur: &Duration{c.MinTLSCertDuration()},
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
EnableSSHCA: &enableSSHCA,
DisableRenewal: &disableRenewal,
AllowRenewalAfterExpiry: &allowRenewalAfterExpiry,
}
}
@ -109,14 +109,14 @@ func (c *Claimer) IsDisableRenewal() bool {
return *c.claims.DisableRenewal
}
// AllowRenewAfterExpiry returns if the renewal flow is authorized if the
// AllowRenewalAfterExpiry returns if the renewal flow is authorized if the
// certificate is expired. If the property is not set within the provisioner
// then the global value from the authority configuration will be used.
func (c *Claimer) AllowRenewAfterExpiry() bool {
if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil {
return *c.global.AllowRenewAfterExpiry
func (c *Claimer) AllowRenewalAfterExpiry() bool {
if c.claims == nil || c.claims.AllowRenewalAfterExpiry == nil {
return *c.global.AllowRenewalAfterExpiry
}
return *c.claims.AllowRenewAfterExpiry
return *c.claims.AllowRenewalAfterExpiry
}
// DefaultSSHCertDuration returns the default SSH certificate duration for the

@ -3,6 +3,7 @@ package provisioner
import (
"context"
"crypto/x509"
"net/http"
"regexp"
"strings"
"time"
@ -21,14 +22,19 @@ type Controller struct {
IdentityFunc GetIdentityFunc
AuthorizeRenewFunc AuthorizeRenewFunc
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
policy *policyEngine
}
// NewController initializes a new provisioner controller.
func NewController(p Interface, claims *Claims, config Config) (*Controller, error) {
func NewController(p Interface, claims *Claims, config Config, options *Options) (*Controller, error) {
claimer, err := NewClaimer(claims, config.Claims)
if err != nil {
return nil, err
}
policy, err := newPolicyEngine(options)
if err != nil {
return nil, err
}
return &Controller{
Interface: p,
Audiences: &config.Audiences,
@ -36,6 +42,7 @@ func NewController(p Interface, claims *Claims, config Config) (*Controller, err
IdentityFunc: config.GetIdentityFunc,
AuthorizeRenewFunc: config.AuthorizeRenewFunc,
AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc,
policy: policy,
}, nil
}
@ -124,8 +131,10 @@ func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certif
if now.Before(cert.NotBefore) {
return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano))
}
if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() {
return errs.Unauthorized("certificate has expired")
if now.After(cert.NotAfter) && !p.Claimer.AllowRenewalAfterExpiry() {
// return a custom 401 Unauthorized error with a clearer message for the client
// TODO(hs): these errors likely need to be refactored as a whole; HTTP status codes shouldn't be in this layer.
return errs.New(http.StatusUnauthorized, "The request lacked necessary authorization to be completed: certificate expired on %s", cert.NotAfter)
}
return nil
@ -144,7 +153,7 @@ func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Cert
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
return errs.Unauthorized("certificate is not yet valid")
}
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() {
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewalAfterExpiry() {
return errs.Unauthorized("certificate has expired")
}
@ -192,3 +201,10 @@ func SanitizeSSHUserPrincipal(email string) string {
}
}, strings.ToLower(email))
}
func (c *Controller) getPolicy() *policyEngine {
if c == nil {
return nil
}
return c.policy
}

@ -9,6 +9,8 @@ import (
"time"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/authority/policy"
)
var trueValue = true
@ -30,11 +32,40 @@ func mustDuration(t *testing.T, s string) *Duration {
return d
}
func mustNewPolicyEngine(t *testing.T, options *Options) *policyEngine {
t.Helper()
c, err := newPolicyEngine(options)
if err != nil {
t.Fatal(err)
}
return c
}
func TestNewController(t *testing.T) {
options := &Options{
X509: &X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"*.local"},
},
},
SSH: &SSHOptions{
Host: &policy.SSHHostCertificateOptions{
AllowedNames: &policy.SSHNameOptions{
DNSDomains: []string{"*.local"},
},
},
User: &policy.SSHUserCertificateOptions{
AllowedNames: &policy.SSHNameOptions{
EmailAddresses: []string{"@example.com"},
},
},
},
}
type args struct {
p Interface
claims *Claims
config Config
p Interface
claims *Claims
config Config
options *Options
}
tests := []struct {
name string
@ -45,7 +76,7 @@ func TestNewController(t *testing.T) {
{"ok", args{&JWK{}, nil, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, &Controller{
}, nil}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
@ -55,24 +86,49 @@ func TestNewController(t *testing.T) {
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, &Controller{
}, nil}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, globalProvisionerClaims),
}, false},
{"ok with claims and options", args{&JWK{}, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}, options}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, globalProvisionerClaims),
policy: mustNewPolicyEngine(t, options),
}, false},
{"fail claimer", args{&JWK{}, &Claims{
MinTLSDur: mustDuration(t, "24h"),
MaxTLSDur: mustDuration(t, "2h"),
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}, nil}, nil, true},
{"fail options", args{&JWK{}, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}, &Options{
X509: &X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"**.local"},
},
},
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config)
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config, tt.args.options)
if (err != nil) != tt.wantErr {
t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr)
return
@ -160,13 +216,13 @@ func TestController_AuthorizeRenew(t *testing.T) {
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return nil
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, false},
@ -231,13 +287,13 @@ func TestController_AuthorizeSSHRenew(t *testing.T) {
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return nil
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, false},
@ -296,7 +352,7 @@ func TestDefaultAuthorizeRenew(t *testing.T) {
}}, false},
{"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
Claimer: mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
@ -354,7 +410,7 @@ func TestDefaultAuthorizeSSHRenew(t *testing.T) {
}}, false},
{"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
Claimer: mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),

@ -14,10 +14,12 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
)
// gcpCertsURL is the url that serves Google OAuth2 public keys.
@ -212,7 +214,7 @@ func (p *GCP) Init(config Config) (err error) {
}
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
@ -262,6 +264,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
}
return append(so,
p,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName),
@ -269,6 +272,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// validators
defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
), nil
}
@ -421,6 +425,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
signOptions = append(signOptions, templateOptions)
return append(signOptions,
p,
// Validate user SignSSHOptions.
sshCertOptionsValidator(defaults),
// Set the validity bounds if not set.
@ -431,5 +436,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
), nil
}

@ -516,9 +516,9 @@ func TestGCP_AuthorizeSign(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{t1}, 5, http.StatusOK, false},
{"ok", p2, args{t2}, 10, http.StatusOK, false},
{"ok", p3, args{t3}, 5, http.StatusOK, false},
{"ok", p1, args{t1}, 7, http.StatusOK, false},
{"ok", p2, args{t2}, 12, http.StatusOK, false},
{"ok", p3, args{t3}, 7, http.StatusOK, false},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
{"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true},
@ -545,9 +545,10 @@ func TestGCP_AuthorizeSign(t *testing.T) {
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
default:
assert.Len(t, tt.wantLen, got)
assert.Equals(t, tt.wantLen, len(got))
for _, o := range got {
switch v := o.(type) {
case *GCP:
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeGCP)
@ -570,6 +571,8 @@ func TestGCP_AuthorizeSign(t *testing.T) {
assert.Equals(t, v, nil)
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}

@ -7,10 +7,12 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
)
// jwtPayload extends jwt.Claims with step attributes.
@ -97,7 +99,7 @@ func (p *JWK) Init(config Config) (err error) {
return errors.New("provisioner key cannot be empty")
}
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
@ -141,6 +143,7 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
// revoke the certificate with serial number in the `sub` property.
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
}
@ -170,6 +173,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
}
return []SignOption{
p,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
@ -179,6 +183,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
defaultPublicKeyValidator{},
defaultSANsValidator(claims.SANs),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
}, nil
}
@ -187,6 +192,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals.
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRewew and AuthorizeSSHRenew)
return p.ctl.AuthorizeRenew(ctx, cert)
}
@ -251,6 +257,7 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
}
return append(signOptions,
p,
// Set the validity bounds if not set.
&sshDefaultDuration{p.ctl.Claimer},
// Validate public key
@ -259,11 +266,14 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
&sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()),
), nil
}
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke)
// TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
}

@ -297,9 +297,10 @@ func TestJWK_AuthorizeSign(t *testing.T) {
}
} else {
if assert.NotNil(t, got) {
assert.Len(t, 7, got)
assert.Equals(t, 9, len(got))
for _, o := range got {
switch v := o.(type) {
case *JWK:
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeJWK)
@ -316,6 +317,8 @@ func TestJWK_AuthorizeSign(t *testing.T) {
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case defaultSANsValidator:
assert.Equals(t, []string(v), tt.sans)
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}

@ -10,11 +10,13 @@ import (
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
)
// NOTE: There can be at most one kubernetes service account provisioner configured
@ -91,6 +93,7 @@ func (p *K8sSA) GetEncryptedKey() (string, string, bool) {
// Init initializes and validates the fields of a K8sSA type.
func (p *K8sSA) Init(config Config) (err error) {
switch {
case p.Type == "":
return errors.New("provisioner type cannot be empty")
@ -137,7 +140,7 @@ func (p *K8sSA) Init(config Config) (err error) {
p.kauthn = k8s.AuthenticationV1()
*/
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
@ -231,6 +234,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
}
return []SignOption{
p,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeK8sSA, p.Name, ""),
@ -238,6 +242,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// validators
defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
}, nil
}
@ -270,6 +275,7 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
signOptions := []SignOption{templateOptions}
return append(signOptions,
p,
// Require type, key-id and principals in the SignSSHOptions.
&sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true},
// Set the validity bounds if not set.
@ -280,6 +286,8 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
&sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()),
), nil
}

@ -280,9 +280,9 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
tot := 0
for _, o := range opts {
switch v := o.(type) {
case *K8sSA:
case certificateOptionsFunc:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeK8sSA)
@ -295,12 +295,13 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
case *validityValidator:
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
tot++
}
assert.Equals(t, tot, 5)
assert.Equals(t, 7, len(opts))
}
}
}
@ -367,9 +368,10 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
tot := 0
assert.Len(t, 8, opts)
for _, o := range opts {
switch v := o.(type) {
case Interface:
case sshCertificateOptionsFunc:
case *sshCertOptionsRequireValidator:
assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true})
@ -379,12 +381,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
case *sshCertDefaultValidator:
case *sshDefaultDuration:
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshNamePolicyValidator:
assert.Equals(t, nil, v.userPolicyEngine)
assert.Equals(t, nil, v.hostPolicyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
tot++
}
assert.Equals(t, tot, 6)
}
}
}

@ -61,3 +61,16 @@ func MethodFromContext(ctx context.Context) Method {
m, _ := ctx.Value(methodKey{}).(Method)
return m
}
type tokenKey struct{}
// NewContextWithToken creates a new context with the given token.
func NewContextWithToken(ctx context.Context, token string) context.Context {
return context.WithValue(ctx, tokenKey{}, token)
}
// TokenFromContext returns the token stored in the given context.
func TokenFromContext(ctx context.Context) (string, bool) {
token, ok := ctx.Value(tokenKey{}).(string)
return token, ok
}

@ -10,12 +10,14 @@ import (
"github.com/pkg/errors"
nebula "github.com/slackhq/nebula/cert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x25519"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/errs"
)
const (
@ -61,7 +63,7 @@ func (p *Nebula) Init(config Config) (err error) {
}
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
@ -144,6 +146,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
}
return []SignOption{
p,
templateOptions,
// modifiers / withOptions
newProvisionerExtensionOption(TypeNebula, p.Name, ""),
@ -160,6 +163,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
},
defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
}, nil
}
@ -246,6 +250,7 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
}
return append(signOptions,
p,
templateOptions,
// Checks the validity bounds, and set the validity if has not been set.
&sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter},
@ -255,6 +260,8 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
), nil
}

@ -38,7 +38,7 @@ func (p *noop) Init(config Config) error {
}
func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{}, nil
return []SignOption{p}, nil
}
func (p *noop) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
@ -50,7 +50,7 @@ func (p *noop) AuthorizeRevoke(ctx context.Context, token string) error {
}
func (p *noop) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{}, nil
return []SignOption{p}, nil
}
func (p *noop) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save