Merge branch 'master' into context-authority

pull/914/head
Mariano Cano 2 years ago
commit d461918eb0

@ -1,4 +1,20 @@
### Description
Please describe your pull request.
<!---
Please provide answers in the spaces below each prompt, where applicable.
Not every PR requires responses for each prompt.
Use your discretion.
-->
#### Name of feature:
#### Pain or issue this feature alleviates:
#### Why is this important to the project (if not answered above):
#### Is there documentation on how to use this feature? If so, where?
#### In what environments or workflows is this feature supported?
#### In what environments or workflows is this feature explicitly NOT supported (if any)?
#### Supporting links/other PRs/issues:
💔Thank you!

@ -7,6 +7,8 @@ import (
"time"
"go.step.sm/crypto/jose"
"github.com/smallstep/certificates/authority/policy"
)
// Account is a subset of the internal account type containing only those
@ -43,15 +45,63 @@ func KeyToID(jwk *jose.JSONWebKey) (string, error) {
return base64.RawURLEncoding.EncodeToString(kid), nil
}
// PolicyNames contains ACME account level policy names
type PolicyNames struct {
DNSNames []string `json:"dns"`
IPRanges []string `json:"ips"`
}
// X509Policy contains ACME account level X.509 policy
type X509Policy struct {
Allowed PolicyNames `json:"allow"`
Denied PolicyNames `json:"deny"`
AllowWildcardNames bool `json:"allowWildcardNames"`
}
// Policy is an ACME Account level policy
type Policy struct {
X509 X509Policy `json:"x509"`
}
func (p *Policy) GetAllowedNameOptions() *policy.X509NameOptions {
if p == nil {
return nil
}
return &policy.X509NameOptions{
DNSDomains: p.X509.Allowed.DNSNames,
IPRanges: p.X509.Allowed.IPRanges,
}
}
func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions {
if p == nil {
return nil
}
return &policy.X509NameOptions{
DNSDomains: p.X509.Denied.DNSNames,
IPRanges: p.X509.Denied.IPRanges,
}
}
// AreWildcardNamesAllowed returns if wildcard names
// like *.example.com are allowed to be signed.
// Defaults to false.
func (p *Policy) AreWildcardNamesAllowed() bool {
if p == nil {
return false
}
return p.X509.AllowWildcardNames
}
// ExternalAccountKey is an ACME External Account Binding key.
type ExternalAccountKey struct {
ID string `json:"id"`
ProvisionerID string `json:"provisionerID"`
Reference string `json:"reference"`
AccountID string `json:"-"`
KeyBytes []byte `json:"-"`
HmacKey []byte `json:"-"`
CreatedAt time.Time `json:"createdAt"`
BoundAt time.Time `json:"boundAt,omitempty"`
Policy *Policy `json:"policy,omitempty"`
}
// AlreadyBound returns whether this EAK is already bound to
@ -68,6 +118,6 @@ func (eak *ExternalAccountKey) BindTo(account *Account) error {
}
eak.AccountID = account.ID
eak.BoundAt = time.Now()
eak.KeyBytes = []byte{} // clearing the key bytes; can only be used once
eak.HmacKey = []byte{} // clearing the key bytes; can only be used once
return nil
}

@ -7,8 +7,9 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
)
func TestKeyToID(t *testing.T) {
@ -95,7 +96,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
ID: "eakID",
ProvisionerID: "provID",
Reference: "ref",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
},
acct: &Account{
ID: "accountID",
@ -108,7 +109,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
ID: "eakID",
ProvisionerID: "provID",
Reference: "ref",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
AccountID: "someAccountID",
BoundAt: boundAt,
},
@ -138,7 +139,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
assert.Equals(t, ae.Subproblems, tt.err.Subproblems)
} else {
assert.Equals(t, eak.AccountID, acct.ID)
assert.Equals(t, eak.KeyBytes, []byte{})
assert.Equals(t, eak.HmacKey, []byte{})
assert.NotNil(t, eak.BoundAt)
}
})

@ -134,8 +134,7 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
}
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
err := eak.BindTo(acc)
if err != nil {
if err := eak.BindTo(acc); err != nil {
render.Error(w, err)
return
}

@ -13,10 +13,12 @@ import (
"github.com/go-chi/chi"
"github.com/pkg/errors"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
)
var (
@ -29,6 +31,22 @@ var (
}
)
type fakeProvisioner struct{}
func (*fakeProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
return nil
}
func (*fakeProvisioner) AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
return nil, nil
}
func (*fakeProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { return nil }
func (*fakeProvisioner) GetID() string { return "" }
func (*fakeProvisioner) GetName() string { return "" }
func (*fakeProvisioner) DefaultTLSCertDuration() time.Duration { return 0 }
func (*fakeProvisioner) GetOptions() *provisioner.Options { return nil }
func newProv() acme.Provisioner {
// Initialize provisioners
p := &provisioner.ACME{
@ -41,6 +59,19 @@ func newProv() acme.Provisioner {
return p
}
func newProvWithOptions(options *provisioner.Options) acme.Provisioner {
// Initialize provisioners
p := &provisioner.ACME{
Type: "ACME",
Name: "test@acme-<test>provisioner.com",
Options: options,
}
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
fmt.Printf("%v", err)
}
return p
}
func newACMEProv(t *testing.T) *provisioner.ACME {
p := newProv()
a, ok := p.(*provisioner.ACME)
@ -50,6 +81,15 @@ func newACMEProv(t *testing.T) *provisioner.ACME {
return a
}
func newACMEProvWithOptions(t *testing.T, options *provisioner.Options) *provisioner.ACME {
p := newProvWithOptions(options)
a, ok := p.(*provisioner.ACME)
if !ok {
t.Fatal("not a valid ACME provisioner")
}
return a
}
func createEABJWS(jwk *jose.JSONWebKey, hmacKey []byte, keyID, u string) (*jose.JSONWebSignature, error) {
signer, err := jose.NewSigner(
jose.SigningKey{
@ -507,16 +547,9 @@ func TestHandler_NewAccount(t *testing.T) {
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
scepProvisioner := &provisioner.SCEP{
Type: "SCEP",
Name: "test@scep-<test>provisioner.com",
}
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
assert.FatalError(t, err)
}
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = acme.NewProvisionerContext(ctx, scepProvisioner)
ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
return test{
db: &acme.MockDB{},
ctx: ctx,
@ -563,7 +596,7 @@ func TestHandler_NewAccount(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(),
}
return test{
@ -737,7 +770,7 @@ func TestHandler_NewAccount(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(),
}, nil
},

@ -4,8 +4,9 @@ import (
"context"
"encoding/json"
"github.com/smallstep/certificates/acme"
"go.step.sm/crypto/jose"
"github.com/smallstep/certificates/acme"
)
// ExternalAccountBinding represents the ACME externalAccountBinding JWS
@ -56,11 +57,19 @@ func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest)
return nil, acme.WrapErrorISE(err, "error retrieving external account key")
}
if externalAccountKey == nil {
return nil, acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key")
}
if len(externalAccountKey.HmacKey) == 0 {
return nil, acme.NewError(acme.ErrorServerInternalType, "external account binding key with id '%s' does not have secret bytes", keyID)
}
if externalAccountKey.AlreadyBound() {
return nil, acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", keyID, externalAccountKey.AccountID, externalAccountKey.BoundAt)
}
payload, err := eabJWS.Verify(externalAccountKey.KeyBytes)
payload, err := eabJWS.Verify(externalAccountKey.HmacKey)
if err != nil {
return nil, acme.WrapErrorISE(err, "error verifying externalAccountBinding signature")
}

@ -9,10 +9,11 @@ import (
"time"
"github.com/pkg/errors"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
)
func Test_keysAreEqual(t *testing.T) {
@ -152,7 +153,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: createdAt,
}, nil
},
@ -166,7 +167,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: createdAt,
},
err: nil,
@ -187,16 +188,10 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
}
b, err := json.Marshal(nar)
assert.FatalError(t, err)
scepProvisioner := &provisioner.SCEP{
Type: "SCEP",
Name: "test@scep-<test>provisioner.com",
}
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
assert.FatalError(t, err)
}
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = acme.NewProvisionerContext(ctx, scepProvisioner)
ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
return test{
ctx: ctx,
err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"),
@ -418,6 +413,112 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
err: acme.NewErrorISE("error retrieving external account key"),
}
},
"fail/db.GetExternalAccountKey-nil": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
assert.FatalError(t, err)
eab := &ExternalAccountBinding{}
err = json.Unmarshal(rawEABJWS, &eab)
assert.FatalError(t, err)
nar := &NewAccountRequest{
Contact: []string{"foo", "bar"},
ExternalAccountBinding: eab,
}
payloadBytes, err := json.Marshal(nar)
assert.FatalError(t, err)
so := new(jose.SignerOptions)
so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
so.WithHeader("url", url)
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
Key: jwk.Key,
}, so)
assert.FatalError(t, err)
jws, err := signer.Sign(payloadBytes)
assert.FatalError(t, err)
raw, err := jws.CompactSerialize()
assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
db: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) {
return nil, nil
},
},
ctx: ctx,
nar: &NewAccountRequest{
Contact: []string{"foo", "bar"},
ExternalAccountBinding: eab,
},
eak: nil,
err: acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key"),
}
},
"fail/db.GetExternalAccountKey-no-keybytes": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
assert.FatalError(t, err)
eab := &ExternalAccountBinding{}
err = json.Unmarshal(rawEABJWS, &eab)
assert.FatalError(t, err)
nar := &NewAccountRequest{
Contact: []string{"foo", "bar"},
ExternalAccountBinding: eab,
}
payloadBytes, err := json.Marshal(nar)
assert.FatalError(t, err)
so := new(jose.SignerOptions)
so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
so.WithHeader("url", url)
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
Key: jwk.Key,
}, so)
assert.FatalError(t, err)
jws, err := signer.Sign(payloadBytes)
assert.FatalError(t, err)
raw, err := jws.CompactSerialize()
assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
createdAt := time.Now()
return test{
db: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) {
return &acme.ExternalAccountKey{
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
CreatedAt: createdAt,
AccountID: "some-account-id",
HmacKey: []byte{},
}, nil
},
},
ctx: ctx,
nar: &NewAccountRequest{
Contact: []string{"foo", "bar"},
ExternalAccountBinding: eab,
},
eak: nil,
err: acme.NewError(acme.ErrorServerInternalType, "external account binding key with id 'eakID' does not have secret bytes"),
}
},
"fail/db.GetExternalAccountKey-wrong-provisioner": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
@ -510,6 +611,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
Reference: "testeak",
CreatedAt: createdAt,
AccountID: "some-account-id",
HmacKey: []byte{1, 3, 3, 7},
BoundAt: boundAt,
}, nil
},
@ -564,7 +666,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 2, 3, 4},
HmacKey: []byte{1, 2, 3, 4},
CreatedAt: time.Now(),
}, nil
},
@ -621,7 +723,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(),
}, nil
},
@ -675,7 +777,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(),
}, nil
},
@ -730,7 +832,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(),
}, nil
},
@ -771,7 +873,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
} else {
assert.NotNil(t, tc.eak)
assert.Equals(t, got.ID, tc.eak.ID)
assert.Equals(t, got.KeyBytes, tc.eak.KeyBytes)
assert.Equals(t, got.HmacKey, tc.eak.HmacKey)
assert.Equals(t, got.ProvisionerID, tc.eak.ProvisionerID)
assert.Equals(t, got.Reference, tc.eak.Reference)
assert.Equals(t, got.CreatedAt, tc.eak.CreatedAt)

@ -20,7 +20,6 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
)
@ -93,11 +92,7 @@ func TestHandler_GetDirectory(t *testing.T) {
}
},
"fail/different-provisioner": func(t *testing.T) test {
prov := &provisioner.SCEP{
Type: "SCEP",
Name: "test@scep-<test>provisioner.com",
}
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx := acme.NewProvisionerContext(context.Background(), &fakeProvisioner{})
return test{
ctx: ctx,
statusCode: 500,

@ -16,6 +16,8 @@ import (
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner"
)
// NewOrderRequest represents the body for a NewOrder request.
@ -37,6 +39,8 @@ func (n *NewOrderRequest) Validate() error {
if id.Type == acme.IP && net.ParseIP(id.Value) == nil {
return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value)
}
// TODO(hs): add some validations for DNS domains?
// TODO(hs): combine the errors from this with allow/deny policy, like example error in https://datatracker.ietf.org/doc/html/rfc8555#section-6.7.1
}
return nil
}
@ -70,6 +74,7 @@ var defaultOrderBackdate = time.Minute
// NewOrder ACME api for creating a new order.
func NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ca := mustAuthority(ctx)
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
@ -88,6 +93,7 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
render.Error(w, err)
return
}
var nor NewOrderRequest
if err := json.Unmarshal(payload.value, &nor); err != nil {
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
@ -100,6 +106,48 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
return
}
// TODO(hs): gather all errors, so that we can build one response with ACME subproblems
// include the nor.Validate() error here too, like in the example in the ACME RFC?
acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
var eak *acme.ExternalAccountKey
if acmeProv.RequireEAB {
if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key"))
return
}
}
acmePolicy, err := newACMEPolicyEngine(eak)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating ACME policy engine"))
return
}
for _, identifier := range nor.Identifiers {
// evaluate the ACME account level policy
if err = isIdentifierAllowed(acmePolicy, identifier); err != nil {
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return
}
// evaluate the provisioner level policy
orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value}
if err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier); err != nil {
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return
}
// evaluate the authority level policy
if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return
}
}
now := clock.Now()
// New order.
o := &acme.Order{
@ -150,6 +198,20 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
render.JSONStatus(w, o, http.StatusCreated)
}
func isIdentifierAllowed(acmePolicy policy.X509Policy, identifier acme.Identifier) error {
if acmePolicy == nil {
return nil
}
return acmePolicy.AreSANsAllowed([]string{identifier.Value})
}
func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error) {
if eak == nil {
return nil, nil
}
return policy.NewX509PolicyEngine(eak.Policy)
}
func newAuthorization(ctx context.Context, az *acme.Authorization) error {
if strings.HasPrefix(az.Identifier.Value, "*.") {
az.Wildcard = true
@ -231,7 +293,7 @@ func GetOrder(w http.ResponseWriter, r *http.Request) {
render.JSON(w, o)
}
// FinalizeOrder attemptst to finalize an order and create a certificate.
// FinalizeOrder attempts to finalize an order and create a certificate.
func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)

@ -16,9 +16,13 @@ import (
"github.com/go-chi/chi"
"github.com/pkg/errors"
"go.step.sm/crypto/pemutil"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"go.step.sm/crypto/pemutil"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner"
)
func TestNewOrderRequest_Validate(t *testing.T) {
@ -670,6 +674,7 @@ func TestHandler_NewOrder(t *testing.T) {
baseURL.String(), escProvName)
type test struct {
ca acme.CertificateAuthority
db acme.DB
ctx context.Context
nor *NewOrderRequest
@ -737,7 +742,7 @@ func TestHandler_NewOrder(t *testing.T) {
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("paylod does not exist"),
err: acme.NewErrorISE("payload does not exist"),
}
},
"fail/unmarshal-payload-error": func(t *testing.T) test {
@ -767,6 +772,222 @@ func TestHandler_NewOrder(t *testing.T) {
err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"),
}
},
"fail/acmeProvisionerFromContext-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), &acme.MockProvisioner{})
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, errors.New("force")
},
},
err: acme.NewErrorISE("error retrieving external account binding key: force"),
}
},
"fail/db.GetExternalAccountKeyByAccountID-error": func(t *testing.T) test {
acmeProv := newACMEProv(t)
acmeProv.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, errors.New("force")
},
},
err: acme.NewErrorISE("error retrieving external account binding key: force"),
}
},
"fail/newACMEPolicyEngine-error": func(t *testing.T) test {
acmeProv := newACMEProv(t)
acmeProv.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return &acme.ExternalAccountKey{
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"**.local"},
},
},
},
}, nil
},
},
err: acme.NewErrorISE("error creating ACME policy engine"),
}
},
"fail/isIdentifierAllowed-error": func(t *testing.T) test {
acmeProv := newACMEProv(t)
acmeProv.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return &acme.ExternalAccountKey{
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.local"},
},
},
},
}, nil
},
},
err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"),
}
},
"fail/prov.AuthorizeOrderIdentifier-error": func(t *testing.T) test {
options := &provisioner.Options{
X509: &provisioner.X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"*.local"},
},
},
}
provWithPolicy := newACMEProvWithOptions(t, options)
provWithPolicy.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return &acme.ExternalAccountKey{
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.internal"},
},
},
},
}, nil
},
},
err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"),
}
},
"fail/ca.AreSANsAllowed-error": func(t *testing.T) test {
options := &provisioner.Options{
X509: &provisioner.X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"*.internal"},
},
},
}
provWithPolicy := newACMEProvWithOptions(t, options)
provWithPolicy.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
ca: &mockCA{
MockAreSANsallowed: func(ctx context.Context, sans []string) error {
return errors.New("force: not authorized by authority")
},
},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return &acme.ExternalAccountKey{
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.internal"},
},
},
},
}, nil
},
},
err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"),
}
},
"fail/error-h.newAuthorization": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
@ -782,6 +1003,7 @@ func TestHandler_NewOrder(t *testing.T) {
return test{
ctx: ctx,
statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
assert.Equals(t, ch.AccountID, "accID")
@ -791,6 +1013,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, ch.Value, "zap.internal")
return errors.New("force")
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
err: acme.NewErrorISE("error creating challenge: force"),
}
@ -815,6 +1042,7 @@ func TestHandler_NewOrder(t *testing.T) {
return test{
ctx: ctx,
statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count {
@ -860,6 +1088,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return errors.New("force")
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
err: acme.NewErrorISE("error creating order: force"),
}
@ -886,6 +1119,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx,
statusCode: 201,
nor: nor,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch chCount {
@ -955,6 +1189,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID, *az2ID})
return nil
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
vr: func(t *testing.T, o *acme.Order) {
now := clock.Now()
@ -1000,6 +1239,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx,
statusCode: 201,
nor: nor,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count {
@ -1046,6 +1286,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
vr: func(t *testing.T, o *acme.Order) {
now := clock.Now()
@ -1091,6 +1336,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx,
statusCode: 201,
nor: nor,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count {
@ -1137,6 +1383,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
vr: func(t *testing.T, o *acme.Order) {
now := clock.Now()
@ -1181,6 +1432,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx,
statusCode: 201,
nor: nor,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count {
@ -1227,6 +1479,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
vr: func(t *testing.T, o *acme.Order) {
testBufferDur := 5 * time.Second
@ -1272,6 +1529,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx,
statusCode: 201,
nor: nor,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count {
@ -1318,11 +1576,119 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
vr: func(t *testing.T, o *acme.Order) {
testBufferDur := 5 * time.Second
orderExpiry := now.Add(defaultOrderExpiry)
assert.Equals(t, o.ID, "ordID")
assert.Equals(t, o.Status, acme.StatusPending)
assert.Equals(t, o.Identifiers, nor.Identifiers)
assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)})
assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf))
assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf))
assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry))
assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry))
},
}
},
"ok/default-naf-nbf-with-policy": func(t *testing.T) test {
options := &provisioner.Options{
X509: &provisioner.X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"*.internal"},
},
},
}
provWithPolicy := newACMEProvWithOptions(t, options)
provWithPolicy.RequireEAB = true
acc := &acme.Account{ID: "accID"}
nor := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
var (
ch1, ch2, ch3 **acme.Challenge
az1ID *string
count = 0
)
return test{
ctx: ctx,
statusCode: 201,
nor: nor,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count {
case 0:
ch.ID = "dns"
assert.Equals(t, ch.Type, acme.DNS01)
ch1 = &ch
case 1:
ch.ID = "http"
assert.Equals(t, ch.Type, acme.HTTP01)
ch2 = &ch
case 2:
ch.ID = "tls"
assert.Equals(t, ch.Type, acme.TLSALPN01)
ch3 = &ch
default:
assert.FatalError(t, errors.New("test logic error"))
return errors.New("force")
}
count++
assert.Equals(t, ch.AccountID, "accID")
assert.NotEquals(t, ch.Token, "")
assert.Equals(t, ch.Status, acme.StatusPending)
assert.Equals(t, ch.Value, "zap.internal")
return nil
},
MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error {
az.ID = "az1ID"
az1ID = &az.ID
assert.Equals(t, az.AccountID, "accID")
assert.NotEquals(t, az.Token, "")
assert.Equals(t, az.Status, acme.StatusPending)
assert.Equals(t, az.Identifier, nor.Identifiers[0])
assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3})
assert.Equals(t, az.Wildcard, false)
return nil
},
MockCreateOrder: func(ctx context.Context, o *acme.Order) error {
o.ID = "ordID"
assert.Equals(t, o.AccountID, "accID")
assert.Equals(t, o.ProvisionerID, prov.GetID())
assert.Equals(t, o.Status, acme.StatusPending)
assert.Equals(t, o.Identifiers, nor.Identifiers)
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
vr: func(t *testing.T, o *acme.Order) {
now := clock.Now()
testBufferDur := 5 * time.Second
orderExpiry := now.Add(defaultOrderExpiry)
expNbf := now.Add(-defaultOrderBackdate)
expNaf := now.Add(prov.DefaultTLSCertDuration())
assert.Equals(t, o.ID, "ordID")
assert.Equals(t, o.Status, acme.StatusPending)
assert.Equals(t, o.Identifiers, nor.Identifiers)
@ -1340,6 +1706,7 @@ func TestHandler_NewOrder(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.ca)
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(ctx)
@ -1493,7 +1860,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
db: &acme.MockDB{},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("paylod does not exist"),
err: acme.NewErrorISE("payload does not exist"),
}
},
"fail/unmarshal-payload-error": func(t *testing.T) test {

@ -24,14 +24,16 @@ import (
"github.com/go-chi/chi"
"github.com/google/go-cmp/cmp"
"github.com/pkg/errors"
"golang.org/x/crypto/ocsp"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ocsp"
)
// v is a utility function to return the pointer to an integer
@ -274,14 +276,22 @@ func jwsFinal(sha crypto.Hash, sig []byte, phead, payload string) ([]byte, error
}
type mockCA struct {
MockIsRevoked func(sn string) (bool, error)
MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error
MockIsRevoked func(sn string) (bool, error)
MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error
MockAreSANsallowed func(ctx context.Context, sans []string) error
}
func (m *mockCA) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return nil, nil
}
func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error {
if m.MockAreSANsallowed != nil {
return m.MockAreSANsallowed(ctx, sans)
}
return nil
}
func (m *mockCA) IsRevoked(sn string) (bool, error) {
if m.MockIsRevoked != nil {
return m.MockIsRevoked(sn)

@ -22,6 +22,7 @@ var clock Clock
// CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
AreSANsAllowed(ctx context.Context, sans []string) error
IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error
LoadProvisionerByName(string) (provisioner.Interface, error)
@ -67,6 +68,7 @@ func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker,
// Provisioner is an interface that implements a subset of the provisioner.Interface --
// only those methods required by the ACME api/authority.
type Provisioner interface {
AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
AuthorizeRevoke(ctx context.Context, token string) error
GetID() string
@ -100,14 +102,15 @@ func MustProvisionerFromContext(ctx context.Context) Provisioner {
// MockProvisioner for testing
type MockProvisioner struct {
Mret1 interface{}
Merr error
MgetID func() string
MgetName func() string
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
MauthorizeRevoke func(ctx context.Context, token string) error
MdefaultTLSCertDuration func() time.Duration
MgetOptions func() *provisioner.Options
Mret1 interface{}
Merr error
MgetID func() string
MgetName func() string
MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
MauthorizeRevoke func(ctx context.Context, token string) error
MdefaultTLSCertDuration func() time.Duration
MgetOptions func() *provisioner.Options
}
// GetName mock
@ -118,6 +121,14 @@ func (m *MockProvisioner) GetName() string {
return m.Mret1.(string)
}
// AuthorizeOrderIdentifiers mock
func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
if m.MauthorizeOrderIdentifier != nil {
return m.MauthorizeOrderIdentifier(ctx, identifier)
}
return m.Merr
}
// AuthorizeSign mock
func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
if m.MauthorizeSign != nil {

@ -23,6 +23,7 @@ type DB interface {
GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error)
GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error)
GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error)
DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error
UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error
@ -83,6 +84,7 @@ type MockDB struct {
MockGetExternalAccountKey func(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error)
MockGetExternalAccountKeys func(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error)
MockGetExternalAccountKeyByReference func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
MockGetExternalAccountKeyByAccountID func(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error)
MockDeleteExternalAccountKey func(ctx context.Context, provisionerID, keyID string) error
MockUpdateExternalAccountKey func(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error
@ -191,6 +193,16 @@ func (m *MockDB) GetExternalAccountKeyByReference(ctx context.Context, provision
return m.MockRet1.(*ExternalAccountKey), m.MockError
}
// GetExternalAccountKeyByAccountID mock
func (m *MockDB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) {
if m.MockGetExternalAccountKeyByAccountID != nil {
return m.MockGetExternalAccountKeyByAccountID(ctx, provisionerID, accountID)
} else if m.MockError != nil {
return nil, m.MockError
}
return m.MockRet1.(*ExternalAccountKey), m.MockError
}
// DeleteExternalAccountKey mock
func (m *MockDB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error {
if m.MockDeleteExternalAccountKey != nil {

@ -8,6 +8,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
nosqlDB "github.com/smallstep/nosql"
)
@ -23,7 +24,7 @@ type dbExternalAccountKey struct {
ProvisionerID string `json:"provisionerID"`
Reference string `json:"reference"`
AccountID string `json:"accountID,omitempty"`
KeyBytes []byte `json:"key"`
HmacKey []byte `json:"key"`
CreatedAt time.Time `json:"createdAt"`
BoundAt time.Time `json:"boundAt"`
}
@ -72,7 +73,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer
ID: keyID,
ProvisionerID: provisionerID,
Reference: reference,
KeyBytes: random,
HmacKey: random,
CreatedAt: clock.Now(),
}
@ -99,7 +100,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer
ProvisionerID: dbeak.ProvisionerID,
Reference: dbeak.Reference,
AccountID: dbeak.AccountID,
KeyBytes: dbeak.KeyBytes,
HmacKey: dbeak.HmacKey,
CreatedAt: dbeak.CreatedAt,
BoundAt: dbeak.BoundAt,
}, nil
@ -124,7 +125,7 @@ func (db *DB) GetExternalAccountKey(ctx context.Context, provisionerID, keyID st
ProvisionerID: dbeak.ProvisionerID,
Reference: dbeak.Reference,
AccountID: dbeak.AccountID,
KeyBytes: dbeak.KeyBytes,
HmacKey: dbeak.HmacKey,
CreatedAt: dbeak.CreatedAt,
BoundAt: dbeak.BoundAt,
}, nil
@ -191,7 +192,7 @@ func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor
}
keys = append(keys, &acme.ExternalAccountKey{
ID: eak.ID,
KeyBytes: eak.KeyBytes,
HmacKey: eak.HmacKey,
ProvisionerID: eak.ProvisionerID,
Reference: eak.Reference,
AccountID: eak.AccountID,
@ -226,6 +227,10 @@ func (db *DB) GetExternalAccountKeyByReference(ctx context.Context, provisionerI
return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID)
}
func (db *DB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
return nil, nil
}
func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
externalAccountKeyMutex.Lock()
defer externalAccountKeyMutex.Unlock()
@ -252,7 +257,7 @@ func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string
ProvisionerID: eak.ProvisionerID,
Reference: eak.Reference,
AccountID: eak.AccountID,
KeyBytes: eak.KeyBytes,
HmacKey: eak.HmacKey,
CreatedAt: eak.CreatedAt,
BoundAt: eak.BoundAt,
}

@ -8,6 +8,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
certdb "github.com/smallstep/certificates/db"
@ -32,7 +33,7 @@ func TestDB_getDBExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: "ref",
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(dbeak)
@ -108,7 +109,7 @@ func TestDB_getDBExternalAccountKey(t *testing.T) {
}
} else if assert.Nil(t, tc.err) {
assert.Equals(t, dbeak.ID, tc.dbeak.ID)
assert.Equals(t, dbeak.KeyBytes, tc.dbeak.KeyBytes)
assert.Equals(t, dbeak.HmacKey, tc.dbeak.HmacKey)
assert.Equals(t, dbeak.ProvisionerID, tc.dbeak.ProvisionerID)
assert.Equals(t, dbeak.Reference, tc.dbeak.Reference)
assert.Equals(t, dbeak.CreatedAt, tc.dbeak.CreatedAt)
@ -136,7 +137,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: "ref",
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(dbeak)
@ -154,7 +155,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: "ref",
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
},
}
@ -179,7 +180,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
ProvisionerID: "aDifferentProvID",
Reference: "ref",
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(dbeak)
@ -197,7 +198,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: "ref",
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
},
acmeErr: acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created"),
@ -225,7 +226,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
}
} else if assert.Nil(t, tc.err) {
assert.Equals(t, eak.ID, tc.eak.ID)
assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes)
assert.Equals(t, eak.HmacKey, tc.eak.HmacKey)
assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID)
assert.Equals(t, eak.Reference, tc.eak.Reference)
assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt)
@ -255,7 +256,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
dbref := &dbExternalAccountKeyReference{
@ -288,7 +289,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
},
err: nil,
@ -392,7 +393,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
assert.Equals(t, eak.AccountID, tc.eak.AccountID)
assert.Equals(t, eak.BoundAt, tc.eak.BoundAt)
assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt)
assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes)
assert.Equals(t, eak.HmacKey, tc.eak.HmacKey)
assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID)
assert.Equals(t, eak.Reference, tc.eak.Reference)
}
@ -420,7 +421,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b1, err := json.Marshal(dbeak1)
@ -430,7 +431,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b2, err := json.Marshal(dbeak2)
@ -440,7 +441,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: "aDifferentProvID",
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b3, err := json.Marshal(dbeak3)
@ -513,7 +514,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
},
{
@ -521,7 +522,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
},
},
@ -598,7 +599,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
assert.Equals(t, "", nextCursor)
for i, eak := range eaks {
assert.Equals(t, eak.ID, tc.eaks[i].ID)
assert.Equals(t, eak.KeyBytes, tc.eaks[i].KeyBytes)
assert.Equals(t, eak.HmacKey, tc.eaks[i].HmacKey)
assert.Equals(t, eak.ProvisionerID, tc.eaks[i].ProvisionerID)
assert.Equals(t, eak.Reference, tc.eaks[i].Reference)
assert.Equals(t, eak.CreatedAt, tc.eaks[i].CreatedAt)
@ -627,7 +628,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
dbref := &dbExternalAccountKeyReference{
@ -707,7 +708,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: "aDifferentProvID",
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(dbeak)
@ -730,7 +731,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
dbref := &dbExternalAccountKeyReference{
@ -780,7 +781,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
dbref := &dbExternalAccountKeyReference{
@ -830,7 +831,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
dbref := &dbExternalAccountKeyReference{
@ -953,7 +954,7 @@ func TestDB_CreateExternalAccountKey(t *testing.T) {
assert.Equals(t, string(key), dbeak.ID)
assert.Equals(t, eak.ProvisionerID, dbeak.ProvisionerID)
assert.Equals(t, eak.Reference, dbeak.Reference)
assert.Equals(t, 32, len(dbeak.KeyBytes))
assert.Equals(t, 32, len(dbeak.HmacKey))
assert.False(t, dbeak.CreatedAt.IsZero())
assert.Equals(t, dbeak.AccountID, eak.AccountID)
assert.True(t, dbeak.BoundAt.IsZero())
@ -1078,7 +1079,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(dbeak)
@ -1096,7 +1097,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
return test{
@ -1120,7 +1121,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
assert.Equals(t, dbNew.AccountID, dbeak.AccountID)
assert.Equals(t, dbNew.CreatedAt, dbeak.CreatedAt)
assert.Equals(t, dbNew.BoundAt, dbeak.BoundAt)
assert.Equals(t, dbNew.KeyBytes, dbeak.KeyBytes)
assert.Equals(t, dbNew.HmacKey, dbeak.HmacKey)
return nu, true, nil
},
},
@ -1148,7 +1149,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: "aDifferentProvID",
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(newDBEAK)
@ -1174,7 +1175,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(newDBEAK)
@ -1200,7 +1201,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: provID,
Reference: ref,
AccountID: "",
KeyBytes: []byte{1, 3, 3, 7},
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now,
}
b, err := json.Marshal(newDBEAK)
@ -1237,7 +1238,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
assert.Equals(t, dbeak.AccountID, tc.eak.AccountID)
assert.Equals(t, dbeak.CreatedAt, tc.eak.CreatedAt)
assert.Equals(t, dbeak.BoundAt, tc.eak.BoundAt)
assert.Equals(t, dbeak.KeyBytes, tc.eak.KeyBytes)
assert.Equals(t, dbeak.HmacKey, tc.eak.HmacKey)
}
})
}

@ -268,6 +268,7 @@ func TestOrder_UpdateStatus(t *testing.T) {
type mockSignAuth struct {
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
areSANsAllowed func(ctx context.Context, sans []string) error
loadProvisionerByName func(string) (provisioner.Interface, error)
ret1, ret2 interface{}
err error
@ -282,6 +283,13 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error {
if m.areSANsAllowed != nil {
return m.areSANsAllowed(ctx, sans)
}
return m.err
}
func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface, error) {
if m.loadProvisionerByName != nil {
return m.loadProvisionerByName(name)

@ -3,16 +3,20 @@ package read
import (
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs"
)
// JSON reads JSON from the request body and stores it in the value
// pointed by v.
// pointed to by v.
func JSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequestErr(err, "error decoding json")
@ -21,11 +25,42 @@ func JSON(r io.Reader, v interface{}) error {
}
// ProtoJSON reads JSON from the request body and stores it in the value
// pointed by v.
// pointed to by m.
func ProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r)
if err != nil {
return errs.BadRequestErr(err, "error reading request body")
}
return protojson.Unmarshal(data, m)
switch err := protojson.Unmarshal(data, m); {
case errors.Is(err, proto.Error):
return badProtoJSONError(err.Error())
default:
return err
}
}
// badProtoJSONError is an error type that is returned by ProtoJSON
// when a proto message cannot be unmarshaled. Usually this is caused
// by an error in the request body.
type badProtoJSONError string
// Error implements error for badProtoJSONError
func (e badProtoJSONError) Error() string {
return string(e)
}
// Render implements render.RenderableError for badProtoJSONError
func (e badProtoJSONError) Render(w http.ResponseWriter) {
v := struct {
Type string `json:"type"`
Detail string `json:"detail"`
Message string `json:"message"`
}{
Type: "badRequest",
Detail: "bad request",
// trim the proto prefix for the message
Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")),
}
render.JSONStatus(w, v, http.StatusBadRequest)
}

@ -1,10 +1,21 @@
package read
import (
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"testing/iotest"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/errs"
)
@ -44,3 +55,110 @@ func TestJSON(t *testing.T) {
})
}
}
func TestProtoJSON(t *testing.T) {
p := new(linkedca.Policy) // TODO(hs): can we use something different, so we don't need the import?
type args struct {
r io.Reader
m proto.Message
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "fail/io.ReadAll",
args: args{
r: iotest.ErrReader(errors.New("read error")),
m: p,
},
wantErr: true,
},
{
name: "fail/proto",
args: args{
r: strings.NewReader(`{?}`),
m: p,
},
wantErr: true,
},
{
name: "ok",
args: args{
r: strings.NewReader(`{"x509":{}}`),
m: p,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ProtoJSON(tt.args.r, tt.args.m)
if (err != nil) != tt.wantErr {
t.Errorf("ProtoJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
switch err.(type) {
case badProtoJSONError:
assert.Contains(t, err.Error(), "syntax error")
case *errs.Error:
var ee *errs.Error
if errors.As(err, &ee) {
assert.Equal(t, http.StatusBadRequest, ee.Status)
}
}
return
}
assert.Equal(t, protoreflect.FullName("linkedca.Policy"), proto.MessageName(tt.args.m))
assert.True(t, proto.Equal(&linkedca.Policy{X509: &linkedca.X509Policy{}}, tt.args.m))
})
}
}
func Test_badProtoJSONError_Render(t *testing.T) {
tests := []struct {
name string
e badProtoJSONError
expected string
}{
{
name: "bad proto normal space",
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
expected: "syntax error (line 1:2): invalid value ?",
},
{
name: "bad proto non breaking space",
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
expected: "syntax error (line 1:2): invalid value ?",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
tt.e.Render(w)
res := w.Result()
defer res.Body.Close()
data, err := io.ReadAll(res.Body)
assert.NoError(t, err)
v := struct {
Type string `json:"type"`
Detail string `json:"detail"`
Message string `json:"message"`
}{}
assert.NoError(t, json.Unmarshal(data, &v))
assert.Equal(t, "badRequest", v.Type)
assert.Equal(t, "bad request", v.Detail)
assert.Equal(t, "syntax error (line 1:2): invalid value ?", v.Message)
})
}
}

@ -1,22 +1,15 @@
package api
import (
"context"
"fmt"
"net/http"
"github.com/go-chi/chi"
"go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
)
const (
// provisionerContextKey provisioner key
provisionerContextKey = ContextKey("provisioner")
)
// CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests
@ -40,55 +33,24 @@ type GetExternalAccountKeysResponse struct {
// requireEABEnabled is a middleware that ensures ACME EAB is enabled
// before serving requests that act on ACME EAB credentials.
func requireEABEnabled(next nextHTTP) nextHTTP {
func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
provName := chi.URLParam(r, "provisionerName")
eabEnabled, prov, err := provisionerHasEABEnabled(ctx, provName)
if err != nil {
render.Error(w, err)
prov := linkedca.MustProvisionerFromContext(ctx)
acmeProvisioner := prov.GetDetails().GetACME()
if acmeProvisioner == nil {
render.Error(w, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName()))
return
}
if !eabEnabled {
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
if !acmeProvisioner.RequireEab {
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName()))
return
}
ctx = context.WithValue(ctx, provisionerContextKey, prov)
next(w, r.WithContext(ctx))
}
}
// provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME
// provisioner is set to true and thus has EAB enabled.
func provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) {
var (
p provisioner.Interface
err error
)
auth := mustAuthority(ctx)
db := admin.MustFromContext(ctx)
if p, err = auth.LoadProvisionerByName(provisionerName); err != nil {
return false, nil, admin.WrapErrorISE(err, "error loading provisioner %s", provisionerName)
next(w, r)
}
prov, err := db.GetProvisioner(ctx, p.GetID())
if err != nil {
return false, nil, admin.WrapErrorISE(err, "error getting provisioner with ID: %s", p.GetID())
}
details := prov.GetDetails()
if details == nil {
return false, nil, admin.NewErrorISE("error getting details for provisioner with ID: %s", p.GetID())
}
acmeProvisioner := details.GetACME()
if acmeProvisioner == nil {
return false, nil, admin.NewErrorISE("error getting ACME details for provisioner with ID: %s", p.GetID())
}
return acmeProvisioner.GetRequireEab(), prov, nil
}
type acmeAdminResponderInterface interface {
@ -119,3 +81,72 @@ func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
}
func eakToLinked(k *acme.ExternalAccountKey) *linkedca.EABKey {
if k == nil {
return nil
}
eak := &linkedca.EABKey{
Id: k.ID,
HmacKey: k.HmacKey,
Provisioner: k.ProvisionerID,
Reference: k.Reference,
Account: k.AccountID,
CreatedAt: timestamppb.New(k.CreatedAt),
BoundAt: timestamppb.New(k.BoundAt),
}
if k.Policy != nil {
eak.Policy = &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{},
Deny: &linkedca.X509Names{},
},
}
eak.Policy.X509.Allow.Dns = k.Policy.X509.Allowed.DNSNames
eak.Policy.X509.Allow.Ips = k.Policy.X509.Allowed.IPRanges
eak.Policy.X509.Deny.Dns = k.Policy.X509.Denied.DNSNames
eak.Policy.X509.Deny.Ips = k.Policy.X509.Denied.IPRanges
eak.Policy.X509.AllowWildcardNames = k.Policy.X509.AllowWildcardNames
}
return eak
}
func linkedEAKToCertificates(k *linkedca.EABKey) *acme.ExternalAccountKey {
if k == nil {
return nil
}
eak := &acme.ExternalAccountKey{
ID: k.Id,
ProvisionerID: k.Provisioner,
Reference: k.Reference,
AccountID: k.Account,
HmacKey: k.HmacKey,
CreatedAt: k.CreatedAt.AsTime(),
BoundAt: k.BoundAt.AsTime(),
}
if policy := k.GetPolicy(); policy != nil {
eak.Policy = &acme.Policy{}
if x509 := policy.GetX509(); x509 != nil {
eak.Policy.X509 = acme.X509Policy{}
if allow := x509.GetAllow(); allow != nil {
eak.Policy.X509.Allowed = acme.PolicyNames{}
eak.Policy.X509.Allowed.DNSNames = allow.Dns
eak.Policy.X509.Allowed.IPRanges = allow.Ips
}
if deny := x509.GetDeny(); deny != nil {
eak.Policy.X509.Denied = acme.PolicyNames{}
eak.Policy.X509.Denied.DNSNames = deny.Dns
eak.Policy.X509.Denied.IPRanges = deny.Ips
}
eak.Policy.X509.AllowWildcardNames = x509.AllowWildcardNames
}
}
return eak
}

@ -4,20 +4,24 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/admin"
)
func readProtoJSON(r io.ReadCloser, m proto.Message) error {
@ -43,107 +47,76 @@ func mockMustAuthority(t *testing.T, a adminAuthority) {
func TestHandler_requireEABEnabled(t *testing.T) {
type test struct {
ctx context.Context
adminDB admin.DB
auth adminAuthority
next nextHTTP
next http.HandlerFunc
err *admin.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"fail/h.provisionerHasEABEnabled": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return nil, errors.New("force")
},
"fail/prov.GetDetails": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
Name: "provName",
}
err := admin.NewErrorISE("error loading provisioner provName: force")
err.Message = "error loading provisioner provName: force"
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'")
err.Message = "error getting ACME details for provisioner 'provName'"
return test{
ctx: ctx,
auth: auth,
adminDB: &admin.MockDB{},
err: err,
statusCode: 500,
}
},
"ok/eab-disabled": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
"fail/prov.GetDetails.GetACME": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{},
}
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'")
err.Message = "error getting ACME details for provisioner 'provName'"
return test{
ctx: ctx,
err: err,
statusCode: 500,
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: false,
},
},
},
"ok/eab-disabled": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: false,
},
}, nil
},
},
}
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName")
err.Message = "ACME EAB not enabled for provisioner provName"
err.Message = "ACME EAB not enabled for provisioner 'provName'"
return test{
ctx: ctx,
auth: auth,
adminDB: db,
err: err,
statusCode: 400,
}
},
"ok/eab-enabled": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
prov := &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: true,
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: true,
},
},
},
}, nil
},
},
}
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
return test{
ctx: ctx,
auth: auth,
adminDB: db,
ctx: ctx,
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(nil) // mock response with status 200
},
@ -155,10 +128,7 @@ func TestHandler_requireEABEnabled(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.auth)
ctx := admin.NewContext(tc.ctx, tc.adminDB)
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(ctx)
req := httptest.NewRequest("GET", "/foo", nil).WithContext(tc.ctx)
w := httptest.NewRecorder()
requireEABEnabled(tc.next)(w, req)
res := w.Result()
@ -184,214 +154,6 @@ func TestHandler_requireEABEnabled(t *testing.T) {
}
}
func TestHandler_provisionerHasEABEnabled(t *testing.T) {
type test struct {
adminDB admin.DB
auth adminAuthority
provisionerName string
want bool
err *admin.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return nil, errors.New("force")
},
}
return test{
auth: auth,
adminDB: &admin.MockDB{},
provisionerName: "provName",
want: false,
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
}
},
"fail/db.GetProvisioner": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return nil, errors.New("force")
},
}
return test{
auth: auth,
adminDB: db,
provisionerName: "provName",
want: false,
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
}
},
"fail/prov.GetDetails": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: nil,
}, nil
},
}
return test{
auth: auth,
adminDB: db,
provisionerName: "provName",
want: false,
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
}
},
"fail/details.GetACME": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: nil,
},
},
}, nil
},
}
return test{
auth: auth,
adminDB: db,
provisionerName: "provName",
want: false,
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
}
},
"ok/eab-disabled": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "eab-disabled", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "eab-disabled",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: false,
},
},
},
}, nil
},
}
return test{
adminDB: db,
auth: auth,
provisionerName: "eab-disabled",
want: false,
}
},
"ok/eab-enabled": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "eab-enabled", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "eab-enabled",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: true,
},
},
},
}, nil
},
}
return test{
adminDB: db,
auth: auth,
provisionerName: "eab-enabled",
want: true,
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.auth)
ctx := admin.NewContext(context.Background(), tc.adminDB)
got, prov, err := provisionerHasEABEnabled(ctx, tc.provisionerName)
if (err != nil) != (tc.err != nil) {
t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err)
return
}
if tc.err != nil {
assert.Type(t, &linkedca.Provisioner{}, prov)
assert.Type(t, &admin.Error{}, err)
adminError, _ := err.(*admin.Error)
assert.Equals(t, tc.err.Type, adminError.Type)
assert.Equals(t, tc.err.Status, adminError.Status)
assert.Equals(t, tc.err.StatusCode(), adminError.StatusCode())
assert.Equals(t, tc.err.Message, adminError.Message)
assert.Equals(t, tc.err.Detail, adminError.Detail)
return
}
if got != tc.want {
t.Errorf("Handler.provisionerHasEABEnabled() = %v, want %v", got, tc.want)
}
})
}
}
func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) {
type fields struct {
Reference string
@ -591,3 +353,206 @@ func TestHandler_GetExternalAccountKeys(t *testing.T) {
})
}
}
func Test_eakToLinked(t *testing.T) {
tests := []struct {
name string
k *acme.ExternalAccountKey
want *linkedca.EABKey
}{
{
name: "no-key",
k: nil,
want: nil,
},
{
name: "no-policy",
k: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: nil,
},
want: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: nil,
},
},
{
name: "with-policy",
k: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.local"},
IPRanges: []string{"10.0.0.0/24"},
},
Denied: acme.PolicyNames{
DNSNames: []string{"badhost.local"},
IPRanges: []string{"10.0.0.30"},
},
},
},
},
want: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
Ips: []string{"10.0.0.0/24"},
},
Deny: &linkedca.X509Names{
Dns: []string{"badhost.local"},
Ips: []string{"10.0.0.30"},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := eakToLinked(tt.k); !reflect.DeepEqual(got, tt.want) {
t.Errorf("eakToLinked() = %v, want %v", got, tt.want)
}
})
}
}
func Test_linkedEAKToCertificates(t *testing.T) {
tests := []struct {
name string
k *linkedca.EABKey
want *acme.ExternalAccountKey
}{
{
name: "no-key",
k: nil,
want: nil,
},
{
name: "no-policy",
k: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: nil,
},
want: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: nil,
},
},
{
name: "no-x509-policy",
k: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: &linkedca.Policy{},
},
want: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: &acme.Policy{},
},
},
{
name: "with-x509-policy",
k: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
Ips: []string{"10.0.0.0/24"},
},
Deny: &linkedca.X509Names{
Dns: []string{"badhost.local"},
Ips: []string{"10.0.0.30"},
},
AllowWildcardNames: true,
},
},
},
want: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.local"},
IPRanges: []string{"10.0.0.0/24"},
},
Denied: acme.PolicyNames{
DNSNames: []string{"badhost.local"},
IPRanges: []string{"10.0.0.30"},
},
AllowWildcardNames: true,
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := linkedEAKToCertificates(tt.k); !reflect.DeepEqual(got, tt.want) {
t.Errorf("linkedEAKToCertificates() = %v, want %v", got, tt.want)
}
})
}
}

@ -29,6 +29,10 @@ type adminAuthority interface {
LoadProvisionerByID(id string) (provisioner.Interface, error)
UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error
RemoveProvisioner(ctx context.Context, id string) error
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
RemoveAuthorityPolicy(ctx context.Context) error
}
// CreateAdminRequest represents the body for a CreateAdmin request.

@ -14,11 +14,13 @@ import (
"github.com/go-chi/chi"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/timestamppb"
)
type mockAdminAuthority struct {
@ -37,6 +39,11 @@ type mockAdminAuthority struct {
MockLoadProvisionerByID func(id string) (provisioner.Interface, error)
MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error
MockRemoveProvisioner func(ctx context.Context, id string) error
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
MockCreateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
MockUpdateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
MockRemoveAuthorityPolicy func(ctx context.Context) error
}
func (m *mockAdminAuthority) IsAdminAPIEnabled() bool {
@ -130,6 +137,34 @@ func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) e
return m.MockErr
}
func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
if m.MockGetAuthorityPolicy != nil {
return m.MockGetAuthorityPolicy(ctx)
}
return m.MockRet1.(*linkedca.Policy), m.MockErr
}
func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
if m.MockCreateAuthorityPolicy != nil {
return m.MockCreateAuthorityPolicy(ctx, adm, policy)
}
return m.MockRet1.(*linkedca.Policy), m.MockErr
}
func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
if m.MockUpdateAuthorityPolicy != nil {
return m.MockUpdateAuthorityPolicy(ctx, adm, policy)
}
return m.MockRet1.(*linkedca.Policy), m.MockErr
}
func (m *mockAdminAuthority) RemoveAuthorityPolicy(ctx context.Context) error {
if m.MockRemoveAuthorityPolicy != nil {
return m.MockRemoveAuthorityPolicy(ctx)
}
return m.MockErr
}
func TestCreateAdminRequest_Validate(t *testing.T) {
type fields struct {
Subject string

@ -2,6 +2,7 @@ package api
import (
"context"
"net/http"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
@ -11,22 +12,22 @@ import (
// Handler is the Admin API request handler.
type Handler struct {
acmeResponder acmeAdminResponderInterface
acmeResponder acmeAdminResponderInterface
policyResponder policyAdminResponderInterface
}
// Route traffic and implement the Router interface.
//
// Deprecated: use Route(r api.Router, acmeResponder acmeAdminResponderInterface)
func (h *Handler) Route(r api.Router) {
Route(r, h.acmeResponder)
Route(r, h.acmeResponder, h.policyResponder)
}
// NewHandler returns a new Authority Config Handler.
//
// Deprecated: use Route(r api.Router, acmeResponder acmeAdminResponderInterface)
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface) api.RouterHandler {
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface, policyResponder policyAdminResponderInterface) api.RouterHandler {
return &Handler{
acmeResponder: acmeResponder,
acmeResponder: acmeResponder,
policyResponder: policyResponder,
}
}
@ -35,11 +36,35 @@ var mustAuthority = func(ctx context.Context) adminAuthority {
}
// Route traffic and implement the Router interface.
func Route(r api.Router, acmeResponder acmeAdminResponderInterface) {
authnz := func(next nextHTTP) nextHTTP {
func Route(r api.Router, acmeResponder acmeAdminResponderInterface, policyResponder policyAdminResponderInterface) {
authnz := func(next http.HandlerFunc) http.HandlerFunc {
return extractAuthorizeTokenAdmin(requireAPIEnabled(next))
}
enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
return checkAction(next, true)
}
disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
return checkAction(next, false)
}
acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(loadProvisionerByName(requireEABEnabled(next)))
}
authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(enabledInStandalone(next))
}
provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(disabledInStandalone(loadProvisionerByName(next)))
}
acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(disabledInStandalone(loadProvisionerByName(requireEABEnabled(loadExternalAccountKey(next)))))
}
// Provisioners
r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner))
r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners))
@ -55,8 +80,31 @@ func Route(r api.Router, acmeResponder acmeAdminResponderInterface) {
r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin))
// ACME External Account Binding Keys
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(requireEABEnabled(acmeResponder.GetExternalAccountKeys)))
r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(acmeResponder.GetExternalAccountKeys)))
r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(acmeResponder.CreateExternalAccountKey)))
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(acmeResponder.DeleteExternalAccountKey)))
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.CreateExternalAccountKey))
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(acmeResponder.DeleteExternalAccountKey))
// Policy - Authority
r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(policyResponder.GetAuthorityPolicy))
r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(policyResponder.CreateAuthorityPolicy))
r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(policyResponder.UpdateAuthorityPolicy))
r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(policyResponder.DeleteAuthorityPolicy))
// Policy - Provisioner
r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.GetProvisionerPolicy))
r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.CreateProvisionerPolicy))
r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.UpdateProvisionerPolicy))
r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.DeleteProvisionerPolicy))
// Policy - ACME Account
r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
}

@ -1,18 +1,23 @@
package api
import (
"context"
"errors"
"net/http"
"github.com/go-chi/chi"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/admin/db/nosql"
"github.com/smallstep/certificates/authority/provisioner"
)
type nextHTTP = func(http.ResponseWriter, *http.Request)
// requireAPIEnabled is a middleware that ensures the Administration API
// is enabled before servicing requests.
func requireAPIEnabled(next nextHTTP) nextHTTP {
func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !mustAuthority(r.Context()).IsAdminAPIEnabled() {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled"))
@ -23,8 +28,9 @@ func requireAPIEnabled(next nextHTTP) nextHTTP {
}
// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token.
func extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
tok := r.Header.Get("Authorization")
if tok == "" {
render.Error(w, admin.NewError(admin.ErrorUnauthorizedType,
@ -39,16 +45,104 @@ func extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
return
}
ctx = context.WithValue(ctx, adminContextKey, adm)
ctx = linkedca.NewContextWithAdmin(ctx, adm)
next(w, r.WithContext(ctx))
}
}
// ContextKey is the key type for storing and searching for ACME request
// essentials in the context of a request.
type ContextKey string
// loadProvisionerByName is a middleware that searches for a provisioner
// by name and stores it in the context.
func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var (
p provisioner.Interface
err error
)
const (
// adminContextKey account key
adminContextKey = ContextKey("admin")
)
ctx := r.Context()
auth := mustAuthority(ctx)
adminDB := admin.MustFromContext(ctx)
name := chi.URLParam(r, "provisionerName")
// TODO(hs): distinguish 404 vs. 500
if p, err = auth.LoadProvisionerByName(name); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return
}
prov, err := adminDB.GetProvisioner(ctx, p.GetID())
if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name))
return
}
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
next(w, r.WithContext(ctx))
}
}
// checkAction checks if an action is supported in standalone or not
func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// actions allowed in standalone mode are always supported
if supportedInStandalone {
next(w, r)
return
}
// when an action is not supported in standalone mode and when
// using a nosql.DB backend, actions are not supported
if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
"operation not supported in standalone mode"))
return
}
// continue to next http handler
next(w, r)
}
}
// loadExternalAccountKey is a middleware that searches for an ACME
// External Account Key by reference or keyID and stores it in the context.
func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx)
acmeDB := acme.MustDatabaseFromContext(ctx)
reference := chi.URLParam(r, "reference")
keyID := chi.URLParam(r, "keyID")
var (
eak *acme.ExternalAccountKey
err error
)
if keyID != "" {
eak, err = acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID)
} else {
eak, err = acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference)
}
if err != nil {
if errors.Is(err, acme.ErrNotFound) {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error retrieving ACME External Account Key"))
return
}
if eak == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
return
}
linkedEAK := eakToLinked(eak)
ctx = linkedca.NewContextWithExternalAccountKey(ctx, linkedEAK)
next(w, r.WithContext(ctx))
}
}

@ -4,25 +4,32 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/admin"
"go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/smallstep/certificates/authority/admin/db/nosql"
"github.com/smallstep/certificates/authority/provisioner"
)
func TestHandler_requireAPIEnabled(t *testing.T) {
type test struct {
ctx context.Context
auth adminAuthority
next nextHTTP
next http.HandlerFunc
err *admin.Error
statusCode int
}
@ -100,7 +107,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
ctx context.Context
auth adminAuthority
req *http.Request
next nextHTTP
next http.HandlerFunc
err *admin.Error
statusCode int
}
@ -150,7 +157,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
req.Header["Authorization"] = []string{"token"}
createdAt := time.Now()
var deletedAt time.Time
admin := &linkedca.Admin{
adm := &linkedca.Admin{
Id: "adminID",
AuthorityId: "authorityID",
Subject: "admin",
@ -162,20 +169,15 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
auth := &mockAdminAuthority{
MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) {
assert.Equals(t, "token", token)
return admin, nil
return adm, nil
},
}
next := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
a := ctx.Value(adminContextKey) // verifying that the context now has a linkedca.Admin
adm, ok := a.(*linkedca.Admin)
if !ok {
t.Errorf("expected *linkedca.Admin; got %T", a)
return
}
adm := linkedca.MustAdminFromContext(ctx) // verifying that the context now has a linkedca.Admin
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})}
if !cmp.Equal(admin, adm, opts...) {
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(admin, adm, opts...))
if !cmp.Equal(adm, adm, opts...) {
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(adm, adm, opts...))
}
w.Write(nil) // mock response with status 200
}
@ -218,3 +220,452 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
})
}
}
func TestHandler_loadProvisionerByName(t *testing.T) {
type test struct {
adminDB admin.DB
auth adminAuthority
ctx context.Context
next http.HandlerFunc
err *admin.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return nil, errors.New("force")
},
}
err := admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName")
err.Message = "error loading provisioner provName: force"
return test{
ctx: ctx,
auth: auth,
adminDB: &admin.MockDB{},
statusCode: 500,
err: err,
}
},
"fail/db.GetProvisioner": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return nil, errors.New("force")
},
}
err := admin.WrapErrorISE(errors.New("force"), "error retrieving provisioner provName")
err.Message = "error retrieving provisioner provName: force"
return test{
ctx: ctx,
auth: auth,
adminDB: db,
statusCode: 500,
err: err,
}
},
"ok": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
}, nil
},
}
return test{
ctx: ctx,
auth: auth,
adminDB: db,
statusCode: 200,
next: func(w http.ResponseWriter, r *http.Request) {
prov := linkedca.MustProvisionerFromContext(r.Context())
assert.NotNil(t, prov)
assert.Equals(t, "provID", prov.GetId())
assert.Equals(t, "provName", prov.GetName())
w.Write(nil) // mock response with status 200
},
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.auth)
ctx := admin.NewContext(tc.ctx, tc.adminDB)
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(ctx)
w := httptest.NewRecorder()
loadProvisionerByName(tc.next)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
body, err := io.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 {
err := admin.Error{}
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
assert.Equals(t, tc.err.Type, err.Type)
assert.Equals(t, tc.err.Message, err.Message)
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
assert.Equals(t, tc.err.Detail, err.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
return
}
})
}
}
func TestHandler_checkAction(t *testing.T) {
type test struct {
adminDB admin.DB
next http.HandlerFunc
supportedInStandalone bool
err *admin.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"standalone-nosql-supported": func(t *testing.T) test {
return test{
supportedInStandalone: true,
adminDB: &nosql.DB{},
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(nil) // mock response with status 200
},
statusCode: 200,
}
},
"standalone-nosql-not-supported": func(t *testing.T) test {
err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported in standalone mode")
err.Message = "operation not supported in standalone mode"
return test{
supportedInStandalone: false,
adminDB: &nosql.DB{},
statusCode: 501,
err: err,
}
},
"standalone-no-nosql-not-supported": func(t *testing.T) test {
err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported")
err.Message = "operation not supported"
return test{
supportedInStandalone: false,
adminDB: &admin.MockDB{},
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(nil) // mock response with status 200
},
statusCode: 200,
err: err,
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
ctx := admin.NewContext(context.Background(), tc.adminDB)
req := httptest.NewRequest("GET", "/foo", nil).WithContext(ctx)
w := httptest.NewRecorder()
checkAction(tc.next, tc.supportedInStandalone)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
body, err := io.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 {
err := admin.Error{}
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
assert.Equals(t, tc.err.Type, err.Type)
assert.Equals(t, tc.err.Message, err.Message)
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
assert.Equals(t, tc.err.Detail, err.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
return
}
})
}
}
func TestHandler_loadExternalAccountKey(t *testing.T) {
type test struct {
ctx context.Context
acmeDB acme.DB
next http.HandlerFunc
err *admin.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"fail/keyID-not-found-error": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("keyID", "key")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "key", keyID)
return nil, acme.ErrNotFound
},
},
err: err,
statusCode: 404,
}
},
"fail/keyID-error": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("keyID", "key")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key")
err.Message = "error retrieving ACME External Account Key: force"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "key", keyID)
return nil, errors.New("force")
},
},
err: err,
statusCode: 500,
}
},
"fail/reference-not-found-error": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("reference", "ref")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "ref", reference)
return nil, acme.ErrNotFound
},
},
err: err,
statusCode: 404,
}
},
"fail/reference-error": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("reference", "ref")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key")
err.Message = "error retrieving ACME External Account Key: force"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "ref", reference)
return nil, errors.New("force")
},
},
err: err,
statusCode: 500,
}
},
"fail/no-key": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("reference", "ref")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "ref", reference)
return nil, nil
},
},
err: err,
statusCode: 404,
}
},
"ok/keyID": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("keyID", "eakID")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
createdAt := time.Now().Add(-1 * time.Hour)
var boundAt time.Time
eak := &acme.ExternalAccountKey{
ID: "eakID",
ProvisionerID: "provID",
CreatedAt: createdAt,
BoundAt: boundAt,
}
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "eakID", keyID)
return eak, nil
},
},
next: func(w http.ResponseWriter, r *http.Request) {
contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context())
assert.NotNil(t, eak)
exp := &linkedca.EABKey{
Id: "eakID",
Provisioner: "provID",
CreatedAt: timestamppb.New(createdAt),
BoundAt: timestamppb.New(boundAt),
}
assert.Equals(t, exp, contextEAK)
w.Write(nil) // mock response with status 200
},
err: nil,
statusCode: 200,
}
},
"ok/reference": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("reference", "ref")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
createdAt := time.Now().Add(-1 * time.Hour)
var boundAt time.Time
eak := &acme.ExternalAccountKey{
ID: "eakID",
ProvisionerID: "provID",
Reference: "ref",
CreatedAt: createdAt,
BoundAt: boundAt,
}
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "ref", reference)
return eak, nil
},
},
next: func(w http.ResponseWriter, r *http.Request) {
contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context())
assert.NotNil(t, eak)
exp := &linkedca.EABKey{
Id: "eakID",
Provisioner: "provID",
Reference: "ref",
CreatedAt: timestamppb.New(createdAt),
BoundAt: timestamppb.New(boundAt),
}
assert.Equals(t, exp, contextEAK)
w.Write(nil) // mock response with status 200
},
err: nil,
statusCode: 200,
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
ctx := acme.NewDatabaseContext(tc.ctx, tc.acmeDB)
req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
loadExternalAccountKey(tc.next)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
body, err := io.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 {
err := admin.Error{}
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
assert.Equals(t, tc.err.Type, err.Type)
assert.Equals(t, tc.err.Message, err.Message)
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
assert.Equals(t, tc.err.Detail, err.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
return
}
})
}
}

@ -0,0 +1,517 @@
package api
import (
"errors"
"net/http"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/policy"
)
type policyAdminResponderInterface interface {
GetAuthorityPolicy(w http.ResponseWriter, r *http.Request)
CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request)
GetProvisionerPolicy(w http.ResponseWriter, r *http.Request)
CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request)
GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
}
// PolicyAdminResponder is responsible for writing ACME admin responses
type PolicyAdminResponder struct {
auth adminAuthority
adminDB admin.DB
acmeDB acme.DB
isLinkedCA bool
}
// NewACMEAdminResponder returns a new ACMEAdminResponder
func NewPolicyAdminResponder(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB) *PolicyAdminResponder {
var isLinkedCA bool
if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok {
isLinkedCA = a.IsLinkedCA()
}
return &PolicyAdminResponder{
auth: auth,
adminDB: adminDB,
acmeDB: acmeDB,
isLinkedCA: isLinkedCA,
}
}
// GetAuthorityPolicy handles the GET /admin/authority/policy request
func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
if err := par.blockLinkedCA(); err != nil {
render.Error(w, err)
return
}
authorityPolicy, err := par.auth.GetAuthorityPolicy(r.Context())
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
return
}
if authorityPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
return
}
render.ProtoJSONStatus(w, authorityPolicy, http.StatusOK)
}
// CreateAuthorityPolicy handles the POST /admin/authority/policy request
func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
if err := par.blockLinkedCA(); err != nil {
render.Error(w, err)
return
}
ctx := r.Context()
authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx)
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
return
}
if authorityPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "authority already has a policy")
render.Error(w, adminErr)
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
return
}
adm := linkedca.MustAdminFromContext(ctx)
var createdPolicy *linkedca.Policy
if createdPolicy, err = par.auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error storing authority policy"))
return
}
render.ProtoJSONStatus(w, createdPolicy, http.StatusCreated)
}
// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request
func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
if err := par.blockLinkedCA(); err != nil {
render.Error(w, err)
return
}
ctx := r.Context()
authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx)
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
return
}
if authorityPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
return
}
adm := linkedca.MustAdminFromContext(ctx)
var updatedPolicy *linkedca.Policy
if updatedPolicy, err = par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error updating authority policy"))
return
}
render.ProtoJSONStatus(w, updatedPolicy, http.StatusOK)
}
// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request
func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
if err := par.blockLinkedCA(); err != nil {
render.Error(w, err)
return
}
ctx := r.Context()
authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx)
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
return
}
if authorityPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
return
}
if err := par.auth.RemoveAuthorityPolicy(ctx); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy"))
return
}
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
}
// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request
func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
if err := par.blockLinkedCA(); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(r.Context())
provisionerPolicy := prov.GetPolicy()
if provisionerPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
return
}
render.ProtoJSONStatus(w, provisionerPolicy, http.StatusOK)
}
// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request
func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
if err := par.blockLinkedCA(); err != nil {
render.Error(w, err)
return
}
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx)
provisionerPolicy := prov.GetPolicy()
if provisionerPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name)
render.Error(w, adminErr)
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
return
}
prov.Policy = newPolicy
if err := par.auth.UpdateProvisioner(ctx, prov); err != nil {
if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error creating provisioner policy"))
return
}
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
}
// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request
func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
if err := par.blockLinkedCA(); err != nil {
render.Error(w, err)
return
}
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx)
provisionerPolicy := prov.GetPolicy()
if provisionerPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
return
}
prov.Policy = newPolicy
if err := par.auth.UpdateProvisioner(ctx, prov); err != nil {
if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error updating provisioner policy"))
return
}
render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
}
// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request
func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
if err := par.blockLinkedCA(); err != nil {
render.Error(w, err)
return
}
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx)
if prov.Policy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
return
}
// remove the policy
prov.Policy = nil
if err := par.auth.UpdateProvisioner(ctx, prov); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy"))
return
}
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
}
func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
if err := par.blockLinkedCA(); err != nil {
render.Error(w, err)
return
}
ctx := r.Context()
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy()
if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
return
}
render.ProtoJSONStatus(w, eakPolicy, http.StatusOK)
}
func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
if err := par.blockLinkedCA(); err != nil {
render.Error(w, err)
return
}
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx)
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy()
if eakPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id)
render.Error(w, adminErr)
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
return
}
eak.Policy = newPolicy
acmeEAK := linkedEAKToCertificates(eak)
if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy"))
return
}
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
}
func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
if err := par.blockLinkedCA(); err != nil {
render.Error(w, err)
return
}
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx)
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy()
if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
return
}
eak.Policy = newPolicy
acmeEAK := linkedEAKToCertificates(eak)
if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy"))
return
}
render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
}
func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
if err := par.blockLinkedCA(); err != nil {
render.Error(w, err)
return
}
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx)
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy()
if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
return
}
// remove the policy
eak.Policy = nil
acmeEAK := linkedEAKToCertificates(eak)
if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy"))
return
}
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
}
// blockLinkedCA blocks all API operations on linked deployments
func (par *PolicyAdminResponder) blockLinkedCA() error {
// temporary blocking linked deployments
if par.isLinkedCA {
return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
}
return nil
}
// isBadRequest checks if an error should result in a bad request error
// returned to the client.
func isBadRequest(err error) bool {
var pe *authority.PolicyError
isPolicyError := errors.As(err, &pe)
return isPolicyError && (pe.Typ == authority.AdminLockOut || pe.Typ == authority.EvaluationFailure || pe.Typ == authority.ConfigurationFailure)
}
func validatePolicy(p *linkedca.Policy) error {
// convert the policy; return early if nil
options := policy.LinkedToCertificates(p)
if options == nil {
return nil
}
var err error
// Initialize a temporary x509 allow/deny policy engine
if _, err = policy.NewX509PolicyEngine(options.GetX509Options()); err != nil {
return err
}
// Initialize a temporary SSH allow/deny policy engine for host certificates
if _, err = policy.NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil {
return err
}
// Initialize a temporary SSH allow/deny policy engine for user certificates
if _, err = policy.NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil {
return err
}
return nil
}

File diff suppressed because it is too large Load Diff

@ -8,18 +8,21 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
)
func TestHandler_GetProvisioner(t *testing.T) {
@ -333,12 +336,12 @@ func TestHandler_CreateProvisioner(t *testing.T) {
return test{
ctx: context.Background(),
body: body,
statusCode: 500,
err: &admin.Error{ // TODO(hs): this probably needs a better error
Type: "",
Status: 500,
Detail: "",
Message: "",
statusCode: 400,
err: &admin.Error{
Type: "badRequest",
Status: 400,
Detail: "bad request",
Message: "proto: syntax error (line 1:2): invalid value !",
},
}
},
@ -419,9 +422,15 @@ func TestHandler_CreateProvisioner(t *testing.T) {
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
assert.Equals(t, tc.err.Type, adminErr.Type)
assert.Equals(t, tc.err.Message, adminErr.Message)
assert.Equals(t, tc.err.Detail, adminErr.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
if strings.HasPrefix(tc.err.Message, "proto:") {
assert.True(t, strings.Contains(adminErr.Message, "syntax error"))
} else {
assert.Equals(t, tc.err.Message, adminErr.Message)
}
return
}
@ -611,12 +620,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
ctx: context.Background(),
body: body,
adminDB: &admin.MockDB{},
statusCode: 500,
err: &admin.Error{ // TODO(hs): this probably needs a better error
Type: "",
Status: 500,
Detail: "",
Message: "",
statusCode: 400,
err: &admin.Error{
Type: "badRequest",
Status: 400,
Detail: "bad request",
Message: "proto: syntax error (line 1:2): invalid value !",
},
}
},
@ -1068,9 +1077,15 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
assert.Equals(t, tc.err.Type, adminErr.Type)
assert.Equals(t, tc.err.Message, adminErr.Message)
assert.Equals(t, tc.err.Detail, adminErr.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
if strings.HasPrefix(tc.err.Message, "proto:") {
assert.True(t, strings.Contains(adminErr.Message, "syntax error"))
} else {
assert.Equals(t, tc.err.Message, adminErr.Message)
}
return
}

@ -69,6 +69,11 @@ type DB interface {
GetAdmins(ctx context.Context) ([]*linkedca.Admin, error)
UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error
DeleteAdmin(ctx context.Context, id string) error
CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
DeleteAuthorityPolicy(ctx context.Context) error
}
type dbKey struct{}
@ -109,6 +114,11 @@ type MockDB struct {
MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error
MockDeleteAdmin func(ctx context.Context, id string) error
MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
MockDeleteAuthorityPolicy func(ctx context.Context) error
MockError error
MockRet1 interface{}
}
@ -202,3 +212,35 @@ func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error {
}
return m.MockError
}
// CreateAuthorityPolicy mock
func (m *MockDB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
if m.MockCreateAuthorityPolicy != nil {
return m.MockCreateAuthorityPolicy(ctx, policy)
}
return m.MockError
}
// GetAuthorityPolicy mock
func (m *MockDB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
if m.MockGetAuthorityPolicy != nil {
return m.MockGetAuthorityPolicy(ctx)
}
return m.MockRet1.(*linkedca.Policy), m.MockError
}
// UpdateAuthorityPolicy mock
func (m *MockDB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
if m.MockUpdateAuthorityPolicy != nil {
return m.MockUpdateAuthorityPolicy(ctx, policy)
}
return m.MockError
}
// DeleteAuthorityPolicy mock
func (m *MockDB) DeleteAuthorityPolicy(ctx context.Context) error {
if m.MockDeleteAuthorityPolicy != nil {
return m.MockDeleteAuthorityPolicy(ctx)
}
return m.MockError
}

@ -11,8 +11,9 @@ import (
)
var (
adminsTable = []byte("admins")
provisionersTable = []byte("provisioners")
adminsTable = []byte("admins")
provisionersTable = []byte("provisioners")
authorityPoliciesTable = []byte("authority_policies")
)
// DB is a struct that implements the AdminDB interface.
@ -23,7 +24,7 @@ type DB struct {
// New configures and returns a new Authority DB backend implemented using a nosql DB.
func New(db nosqlDB.DB, authorityID string) (*DB, error) {
tables := [][]byte{adminsTable, provisionersTable}
tables := [][]byte{adminsTable, provisionersTable, authorityPoliciesTable}
for _, b := range tables {
if err := db.CreateTable(b); err != nil {
return nil, errors.Wrapf(err, "error creating table %s",

@ -0,0 +1,339 @@
package nosql
import (
"context"
"encoding/json"
"fmt"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/nosql"
)
type dbX509Policy struct {
Allow *dbX509Names `json:"allow,omitempty"`
Deny *dbX509Names `json:"deny,omitempty"`
AllowWildcardNames bool `json:"allow_wildcard_names,omitempty"`
}
type dbX509Names struct {
CommonNames []string `json:"cn,omitempty"`
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
EmailAddresses []string `json:"email,omitempty"`
URIDomains []string `json:"uri,omitempty"`
}
type dbSSHPolicy struct {
// User contains SSH user certificate options.
User *dbSSHUserPolicy `json:"user,omitempty"`
// Host contains SSH host certificate options.
Host *dbSSHHostPolicy `json:"host,omitempty"`
}
type dbSSHHostPolicy struct {
Allow *dbSSHHostNames `json:"allow,omitempty"`
Deny *dbSSHHostNames `json:"deny,omitempty"`
}
type dbSSHHostNames struct {
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
Principals []string `json:"principal,omitempty"`
}
type dbSSHUserPolicy struct {
Allow *dbSSHUserNames `json:"allow,omitempty"`
Deny *dbSSHUserNames `json:"deny,omitempty"`
}
type dbSSHUserNames struct {
EmailAddresses []string `json:"email,omitempty"`
Principals []string `json:"principal,omitempty"`
}
type dbPolicy struct {
X509 *dbX509Policy `json:"x509,omitempty"`
SSH *dbSSHPolicy `json:"ssh,omitempty"`
}
type dbAuthorityPolicy struct {
ID string `json:"id"`
AuthorityID string `json:"authorityID"`
Policy *dbPolicy `json:"policy,omitempty"`
}
func (dbap *dbAuthorityPolicy) convert() *linkedca.Policy {
if dbap == nil {
return nil
}
return dbToLinked(dbap.Policy)
}
func (db *DB) getDBAuthorityPolicyBytes(ctx context.Context, authorityID string) ([]byte, error) {
data, err := db.db.Get(authorityPoliciesTable, []byte(authorityID))
if nosql.IsErrNotFound(err) {
return nil, admin.NewError(admin.ErrorNotFoundType, "authority policy not found")
} else if err != nil {
return nil, fmt.Errorf("error loading authority policy: %w", err)
}
return data, nil
}
func (db *DB) unmarshalDBAuthorityPolicy(data []byte) (*dbAuthorityPolicy, error) {
if len(data) == 0 {
return nil, nil
}
var dba = new(dbAuthorityPolicy)
if err := json.Unmarshal(data, dba); err != nil {
return nil, fmt.Errorf("error unmarshaling policy bytes into dbAuthorityPolicy: %w", err)
}
return dba, nil
}
func (db *DB) getDBAuthorityPolicy(ctx context.Context, authorityID string) (*dbAuthorityPolicy, error) {
data, err := db.getDBAuthorityPolicyBytes(ctx, authorityID)
if err != nil {
return nil, err
}
dbap, err := db.unmarshalDBAuthorityPolicy(data)
if err != nil {
return nil, err
}
if dbap == nil {
return nil, nil
}
if dbap.AuthorityID != authorityID {
return nil, admin.NewError(admin.ErrorAuthorityMismatchType,
"authority policy is not owned by authority %s", authorityID)
}
return dbap, nil
}
func (db *DB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
dbap := &dbAuthorityPolicy{
ID: db.authorityID,
AuthorityID: db.authorityID,
Policy: linkedToDB(policy),
}
if err := db.save(ctx, dbap.ID, dbap, nil, "authority_policy", authorityPoliciesTable); err != nil {
return admin.WrapErrorISE(err, "error creating authority policy")
}
return nil
}
func (db *DB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
if err != nil {
return nil, err
}
return dbap.convert(), nil
}
func (db *DB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
old, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
if err != nil {
return err
}
dbap := &dbAuthorityPolicy{
ID: db.authorityID,
AuthorityID: db.authorityID,
Policy: linkedToDB(policy),
}
if err := db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable); err != nil {
return admin.WrapErrorISE(err, "error updating authority policy")
}
return nil
}
func (db *DB) DeleteAuthorityPolicy(ctx context.Context) error {
old, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
if err != nil {
return err
}
if err := db.save(ctx, old.ID, nil, old, "authority_policy", authorityPoliciesTable); err != nil {
return admin.WrapErrorISE(err, "error deleting authority policy")
}
return nil
}
func dbToLinked(p *dbPolicy) *linkedca.Policy {
if p == nil {
return nil
}
r := &linkedca.Policy{}
if x509 := p.X509; x509 != nil {
r.X509 = &linkedca.X509Policy{}
if allow := x509.Allow; allow != nil {
r.X509.Allow = &linkedca.X509Names{}
r.X509.Allow.Dns = allow.DNSDomains
r.X509.Allow.Emails = allow.EmailAddresses
r.X509.Allow.Ips = allow.IPRanges
r.X509.Allow.Uris = allow.URIDomains
r.X509.Allow.CommonNames = allow.CommonNames
}
if deny := x509.Deny; deny != nil {
r.X509.Deny = &linkedca.X509Names{}
r.X509.Deny.Dns = deny.DNSDomains
r.X509.Deny.Emails = deny.EmailAddresses
r.X509.Deny.Ips = deny.IPRanges
r.X509.Deny.Uris = deny.URIDomains
r.X509.Deny.CommonNames = deny.CommonNames
}
r.X509.AllowWildcardNames = x509.AllowWildcardNames
}
if ssh := p.SSH; ssh != nil {
r.Ssh = &linkedca.SSHPolicy{}
if host := ssh.Host; host != nil {
r.Ssh.Host = &linkedca.SSHHostPolicy{}
if allow := host.Allow; allow != nil {
r.Ssh.Host.Allow = &linkedca.SSHHostNames{}
r.Ssh.Host.Allow.Dns = allow.DNSDomains
r.Ssh.Host.Allow.Ips = allow.IPRanges
r.Ssh.Host.Allow.Principals = allow.Principals
}
if deny := host.Deny; deny != nil {
r.Ssh.Host.Deny = &linkedca.SSHHostNames{}
r.Ssh.Host.Deny.Dns = deny.DNSDomains
r.Ssh.Host.Deny.Ips = deny.IPRanges
r.Ssh.Host.Deny.Principals = deny.Principals
}
}
if user := ssh.User; user != nil {
r.Ssh.User = &linkedca.SSHUserPolicy{}
if allow := user.Allow; allow != nil {
r.Ssh.User.Allow = &linkedca.SSHUserNames{}
r.Ssh.User.Allow.Emails = allow.EmailAddresses
r.Ssh.User.Allow.Principals = allow.Principals
}
if deny := user.Deny; deny != nil {
r.Ssh.User.Deny = &linkedca.SSHUserNames{}
r.Ssh.User.Deny.Emails = deny.EmailAddresses
r.Ssh.User.Deny.Principals = deny.Principals
}
}
}
return r
}
func linkedToDB(p *linkedca.Policy) *dbPolicy {
if p == nil {
return nil
}
// return early if x509 nor SSH is set
if p.GetX509() == nil && p.GetSsh() == nil {
return nil
}
r := &dbPolicy{}
// fill x509 policy configuration
if x509 := p.GetX509(); x509 != nil {
r.X509 = &dbX509Policy{}
if allow := x509.GetAllow(); allow != nil {
r.X509.Allow = &dbX509Names{}
if allow.Dns != nil {
r.X509.Allow.DNSDomains = allow.Dns
}
if allow.Ips != nil {
r.X509.Allow.IPRanges = allow.Ips
}
if allow.Emails != nil {
r.X509.Allow.EmailAddresses = allow.Emails
}
if allow.Uris != nil {
r.X509.Allow.URIDomains = allow.Uris
}
if allow.CommonNames != nil {
r.X509.Allow.CommonNames = allow.CommonNames
}
}
if deny := x509.GetDeny(); deny != nil {
r.X509.Deny = &dbX509Names{}
if deny.Dns != nil {
r.X509.Deny.DNSDomains = deny.Dns
}
if deny.Ips != nil {
r.X509.Deny.IPRanges = deny.Ips
}
if deny.Emails != nil {
r.X509.Deny.EmailAddresses = deny.Emails
}
if deny.Uris != nil {
r.X509.Deny.URIDomains = deny.Uris
}
if deny.CommonNames != nil {
r.X509.Deny.CommonNames = deny.CommonNames
}
}
r.X509.AllowWildcardNames = x509.GetAllowWildcardNames()
}
// fill ssh policy configuration
if ssh := p.GetSsh(); ssh != nil {
r.SSH = &dbSSHPolicy{}
if host := ssh.GetHost(); host != nil {
r.SSH.Host = &dbSSHHostPolicy{}
if allow := host.GetAllow(); allow != nil {
r.SSH.Host.Allow = &dbSSHHostNames{}
if allow.Dns != nil {
r.SSH.Host.Allow.DNSDomains = allow.Dns
}
if allow.Ips != nil {
r.SSH.Host.Allow.IPRanges = allow.Ips
}
if allow.Principals != nil {
r.SSH.Host.Allow.Principals = allow.Principals
}
}
if deny := host.GetDeny(); deny != nil {
r.SSH.Host.Deny = &dbSSHHostNames{}
if deny.Dns != nil {
r.SSH.Host.Deny.DNSDomains = deny.Dns
}
if deny.Ips != nil {
r.SSH.Host.Deny.IPRanges = deny.Ips
}
if deny.Principals != nil {
r.SSH.Host.Deny.Principals = deny.Principals
}
}
}
if user := ssh.GetUser(); user != nil {
r.SSH.User = &dbSSHUserPolicy{}
if allow := user.GetAllow(); allow != nil {
r.SSH.User.Allow = &dbSSHUserNames{}
if allow.Emails != nil {
r.SSH.User.Allow.EmailAddresses = allow.Emails
}
if allow.Principals != nil {
r.SSH.User.Allow.Principals = allow.Principals
}
}
if deny := user.GetDeny(); deny != nil {
r.SSH.User.Deny = &dbSSHUserNames{}
if deny.Emails != nil {
r.SSH.User.Deny.EmailAddresses = deny.Emails
}
if deny.Principals != nil {
r.SSH.User.Deny.Principals = deny.Principals
}
}
}
}
return r
}

File diff suppressed because it is too large Load Diff

@ -24,10 +24,12 @@ const (
ErrorBadRequestType
// ErrorNotImplementedType not implemented.
ErrorNotImplementedType
// ErrorUnauthorizedType internal server error.
// ErrorUnauthorizedType unauthorized.
ErrorUnauthorizedType
// ErrorServerInternalType internal server error.
ErrorServerInternalType
// ErrorConflictType conflict.
ErrorConflictType
)
// String returns the string representation of the admin problem type,
@ -48,6 +50,8 @@ func (ap ProblemType) String() string {
return "unauthorized"
case ErrorServerInternalType:
return "internalServerError"
case ErrorConflictType:
return "conflict"
default:
return fmt.Sprintf("unsupported error type '%d'", int(ap))
}
@ -64,7 +68,7 @@ var (
errorServerInternalMetadata = errorMetadata{
typ: ErrorServerInternalType.String(),
details: "the server experienced an internal error",
status: 500,
status: http.StatusInternalServerError,
}
errorMap = map[ProblemType]errorMetadata{
ErrorNotFoundType: {
@ -98,6 +102,11 @@ var (
status: http.StatusUnauthorized,
},
ErrorServerInternalType: errorServerInternalMetadata,
ErrorConflictType: {
typ: ErrorConflictType.String(),
details: "conflict",
status: http.StatusConflict,
},
}
)

@ -59,12 +59,12 @@ func newSubProv(subject, prov string) subProv {
return subProv{subject, prov}
}
// LoadBySubProv a admin by the subject and provisioner name.
// LoadBySubProv loads an admin by subject and provisioner name.
func (c *Collection) LoadBySubProv(sub, provName string) (*linkedca.Admin, bool) {
return loadAdmin(c.bySubProv, newSubProv(sub, provName))
}
// LoadByProvisioner a admin by the subject and provisioner name.
// LoadByProvisioner loads admins by provisioner name.
func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool) {
val, ok := c.byProv.Load(provName)
if !ok {
@ -78,7 +78,7 @@ func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool
}
// Store adds an admin to the collection and enforces the uniqueness of
// admin IDs and amdin subject <-> provisioner name combos.
// admin IDs and admin subject <-> provisioner name combos.
func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error {
// Input validation.
if adm.ProvisionerId != prov.GetID() {

@ -12,10 +12,16 @@ import (
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/pemutil"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/admin"
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
"github.com/smallstep/certificates/authority/administrator"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/cas"
casapi "github.com/smallstep/certificates/cas/apiv1"
@ -26,9 +32,6 @@ import (
"github.com/smallstep/certificates/scep"
"github.com/smallstep/certificates/templates"
"github.com/smallstep/nosql"
"go.step.sm/crypto/pemutil"
"go.step.sm/linkedca"
"golang.org/x/crypto/ssh"
)
// Authority implements the Certificate Authority internal interface.
@ -77,6 +80,9 @@ type Authority struct {
authorizeRenewFunc provisioner.AuthorizeRenewFunc
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
// Policy engines
policyEngine *policy.Engine
adminMutex sync.RWMutex
}
@ -232,6 +238,7 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error {
a.provisioners = provClxn
a.config.AuthorityConfig.Admins = adminList
a.admins = adminClxn
return nil
}
@ -578,6 +585,11 @@ func (a *Authority) init() error {
return err
}
// Load x509 and SSH Policy Engines
if err := a.reloadPolicyEngines(context.Background()); err != nil {
return err
}
// Configure templates, currently only ssh templates are supported.
if a.sshCAHostCertSignKey != nil || a.sshCAUserCertSignKey != nil {
a.templates = a.config.Templates

@ -491,7 +491,7 @@ func TestAuthority_authorizeSign(t *testing.T) {
}
} else {
if assert.Nil(t, tc.err) {
assert.Len(t, 8, got)
assert.Equals(t, 9, len(got)) // number of provisioner.SignOptions returned
}
}
})
@ -1034,7 +1034,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
}
} else {
if assert.Nil(t, tc.err) {
assert.Len(t, 7, got)
assert.Len(t, 8, got) // number of provisioner.SignOptions returned
}
}
})

@ -8,12 +8,15 @@ import (
"time"
"github.com/pkg/errors"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner"
cas "github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db"
kms "github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/certificates/templates"
"go.step.sm/linkedca"
)
const (
@ -95,6 +98,7 @@ type AuthConfig struct {
Admins []*linkedca.Admin `json:"-"`
Template *ASN1DN `json:"template,omitempty"`
Claims *provisioner.Claims `json:"claims,omitempty"`
Policy *policy.Options `json:"policy,omitempty"`
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
Backdate *provisioner.Duration `json:"backdate,omitempty"`
EnableAdmin bool `json:"enableAdmin,omitempty"`

@ -15,16 +15,19 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/tlsutil"
"go.step.sm/crypto/x509util"
"go.step.sm/linkedca"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
)
const uuidPattern = "^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$"
@ -35,6 +38,9 @@ type linkedCaClient struct {
authorityID string
}
// interface guard
var _ admin.DB = (*linkedCaClient)(nil)
type linkedCAClaims struct {
jose.Claims
SANs []string `json:"sans"`
@ -116,6 +122,13 @@ func newLinkedCAClient(token string) (*linkedCaClient, error) {
}, nil
}
// IsLinkedCA is a sentinel function that can be used to
// check if a linkedCaClient is the underlying type of an
// admin.DB interface.
func (c *linkedCaClient) IsLinkedCA() bool {
return true
}
func (c *linkedCaClient) Run() {
c.renewer.Run()
}
@ -340,6 +353,22 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
}
func (c *linkedCaClient) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
return errors.New("not implemented yet")
}
func (c *linkedCaClient) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
return nil, errors.New("not implemented yet")
}
func (c *linkedCaClient) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
return errors.New("not implemented yet")
}
func (c *linkedCaClient) DeleteAuthorityPolicy(ctx context.Context) error {
return errors.New("not implemented yet")
}
func createProvisionerIdentity(prov provisioner.Interface) *linkedca.ProvisionerIdentity {
if prov == nil {
return nil

@ -0,0 +1,258 @@
package authority
import (
"context"
"errors"
"fmt"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/admin"
authPolicy "github.com/smallstep/certificates/authority/policy"
policy "github.com/smallstep/certificates/policy"
)
type policyErrorType int
const (
AdminLockOut policyErrorType = iota + 1
StoreFailure
ReloadFailure
ConfigurationFailure
EvaluationFailure
InternalFailure
)
type PolicyError struct {
Typ policyErrorType
Err error
}
func (p *PolicyError) Error() string {
return p.Err.Error()
}
func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
p, err := a.adminDB.GetAuthorityPolicy(ctx)
if err != nil {
return nil, &PolicyError{
Typ: InternalFailure,
Err: err,
}
}
return p, nil
}
func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil {
return nil, err
}
if err := a.adminDB.CreateAuthorityPolicy(ctx, p); err != nil {
return nil, &PolicyError{
Typ: StoreFailure,
Err: err,
}
}
if err := a.reloadPolicyEngines(ctx); err != nil {
return nil, &PolicyError{
Typ: ReloadFailure,
Err: fmt.Errorf("error reloading policy engines when creating authority policy: %w", err),
}
}
return p, nil
}
func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil {
return nil, err
}
if err := a.adminDB.UpdateAuthorityPolicy(ctx, p); err != nil {
return nil, &PolicyError{
Typ: StoreFailure,
Err: err,
}
}
if err := a.reloadPolicyEngines(ctx); err != nil {
return nil, &PolicyError{
Typ: ReloadFailure,
Err: fmt.Errorf("error reloading policy engines when updating authority policy: %w", err),
}
}
return p, nil
}
func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
if err := a.adminDB.DeleteAuthorityPolicy(ctx); err != nil {
return &PolicyError{
Typ: StoreFailure,
Err: err,
}
}
if err := a.reloadPolicyEngines(ctx); err != nil {
return &PolicyError{
Typ: ReloadFailure,
Err: fmt.Errorf("error reloading policy engines when deleting authority policy: %w", err),
}
}
return nil
}
func (a *Authority) checkAuthorityPolicy(ctx context.Context, currentAdmin *linkedca.Admin, p *linkedca.Policy) error {
// no policy and thus nothing to evaluate; return early
if p == nil {
return nil
}
// get all current admins from the database
allAdmins, err := a.adminDB.GetAdmins(ctx)
if err != nil {
return &PolicyError{
Typ: InternalFailure,
Err: fmt.Errorf("error retrieving admins: %w", err),
}
}
return a.checkPolicy(ctx, currentAdmin, allAdmins, p)
}
func (a *Authority) checkProvisionerPolicy(ctx context.Context, currentAdmin *linkedca.Admin, provName string, p *linkedca.Policy) error {
// no policy and thus nothing to evaluate; return early
if p == nil {
return nil
}
// get all admins for the provisioner; ignoring case in which they're not found
allProvisionerAdmins, _ := a.admins.LoadByProvisioner(provName)
return a.checkPolicy(ctx, currentAdmin, allProvisionerAdmins, p)
}
// checkPolicy checks if a new or updated policy configuration results in the user
// locking themselves or other admins out of the CA.
func (a *Authority) checkPolicy(ctx context.Context, currentAdmin *linkedca.Admin, otherAdmins []*linkedca.Admin, p *linkedca.Policy) error {
// convert the policy; return early if nil
policyOptions := authPolicy.LinkedToCertificates(p)
if policyOptions == nil {
return nil
}
engine, err := authPolicy.NewX509PolicyEngine(policyOptions.GetX509Options())
if err != nil {
return &PolicyError{
Typ: ConfigurationFailure,
Err: err,
}
}
// when an empty X.509 policy is provided, the resulting engine is nil
// and there's no policy to evaluate.
if engine == nil {
return nil
}
// TODO(hs): Provide option to force the policy, even when the admin subject would be locked out?
// check if the admin user that instructed the authority policy to be
// created or updated, would still be allowed when the provided policy
// would be applied.
sans := []string{currentAdmin.GetSubject()}
if err := isAllowed(engine, sans); err != nil {
return err
}
// loop through admins to verify that none of them would be
// locked out when the new policy were to be applied. Returns
// an error with a message that includes the admin subject that
// would be locked out.
for _, adm := range otherAdmins {
sans = []string{adm.GetSubject()}
if err := isAllowed(engine, sans); err != nil {
return err
}
}
// TODO(hs): mask the error message for non-super admins?
return nil
}
// reloadPolicyEngines reloads x509 and SSH policy engines using
// configuration stored in the DB or from the configuration file.
func (a *Authority) reloadPolicyEngines(ctx context.Context) error {
var (
err error
policyOptions *authPolicy.Options
)
if a.config.AuthorityConfig.EnableAdmin {
// temporarily disable policy loading when LinkedCA is in use
if _, ok := a.adminDB.(*linkedCaClient); ok {
return nil
}
linkedPolicy, err := a.adminDB.GetAuthorityPolicy(ctx)
if err != nil {
var ae *admin.Error
if isAdminError := errors.As(err, &ae); (isAdminError && ae.Type != admin.ErrorNotFoundType.String()) || !isAdminError {
return fmt.Errorf("error getting policy to (re)load policy engines: %w", err)
}
}
policyOptions = authPolicy.LinkedToCertificates(linkedPolicy)
} else {
policyOptions = a.config.AuthorityConfig.Policy
}
engine, err := authPolicy.New(policyOptions)
if err != nil {
return err
}
// only update the policy engine when no error was returned
a.policyEngine = engine
return nil
}
func isAllowed(engine authPolicy.X509Policy, sans []string) error {
if err := engine.AreSANsAllowed(sans); err != nil {
var policyErr *policy.NamePolicyError
isNamePolicyError := errors.As(err, &policyErr)
if isNamePolicyError && policyErr.Reason == policy.NotAllowed {
return &PolicyError{
Typ: AdminLockOut,
Err: fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans),
}
}
return &PolicyError{
Typ: EvaluationFailure,
Err: err,
}
}
return nil
}

@ -0,0 +1,114 @@
package policy
import (
"crypto/x509"
"errors"
"fmt"
"golang.org/x/crypto/ssh"
)
// Engine is a container for multiple policies.
type Engine struct {
x509Policy X509Policy
sshUserPolicy UserPolicy
sshHostPolicy HostPolicy
}
// New returns a new Engine using Options.
func New(options *Options) (*Engine, error) {
// if no options provided, return early
if options == nil {
return nil, nil
}
var (
x509Policy X509Policy
sshHostPolicy HostPolicy
sshUserPolicy UserPolicy
err error
)
// initialize the x509 allow/deny policy engine
if x509Policy, err = NewX509PolicyEngine(options.GetX509Options()); err != nil {
return nil, err
}
// initialize the SSH allow/deny policy engine for host certificates
if sshHostPolicy, err = NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil {
return nil, err
}
// initialize the SSH allow/deny policy engine for user certificates
if sshUserPolicy, err = NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil {
return nil, err
}
return &Engine{
x509Policy: x509Policy,
sshHostPolicy: sshHostPolicy,
sshUserPolicy: sshUserPolicy,
}, nil
}
// IsX509CertificateAllowed evaluates an X.509 certificate against
// the X.509 policy (if available) and returns an error if one of the
// names in the certificate is not allowed.
func (e *Engine) IsX509CertificateAllowed(cert *x509.Certificate) error {
// return early if there's no policy to evaluate
if e == nil || e.x509Policy == nil {
return nil
}
// return result of X.509 policy evaluation
return e.x509Policy.IsX509CertificateAllowed(cert)
}
// AreSANsAllowed evaluates the slice of SANs against the X.509 policy
// (if available) and returns an error if one of the SANs is not allowed.
func (e *Engine) AreSANsAllowed(sans []string) error {
// return early if there's no policy to evaluate
if e == nil || e.x509Policy == nil {
return nil
}
// return result of X.509 policy evaluation
return e.x509Policy.AreSANsAllowed(sans)
}
// IsSSHCertificateAllowed evaluates an SSH certificate against the
// user or host policy (if configured) and returns an error if one of the
// principals in the certificate is not allowed.
func (e *Engine) IsSSHCertificateAllowed(cert *ssh.Certificate) error {
// return early if there's no policy to evaluate
if e == nil || (e.sshHostPolicy == nil && e.sshUserPolicy == nil) {
return nil
}
switch cert.CertType {
case ssh.HostCert:
// when no host policy engine is configured, but a user policy engine is
// configured, the host certificate is denied.
if e.sshHostPolicy == nil && e.sshUserPolicy != nil {
return errors.New("authority not allowed to sign ssh host certificates")
}
// return result of SSH host policy evaluation
return e.sshHostPolicy.IsSSHCertificateAllowed(cert)
case ssh.UserCert:
// when no user policy engine is configured, but a host policy engine is
// configured, the user certificate is denied.
if e.sshUserPolicy == nil && e.sshHostPolicy != nil {
return errors.New("authority not allowed to sign ssh user certificates")
}
// return result of SSH user policy evaluation
return e.sshUserPolicy.IsSSHCertificateAllowed(cert)
default:
return fmt.Errorf("unexpected ssh certificate type %q", cert.CertType)
}
}

@ -0,0 +1,194 @@
package policy
// Options is a container for authority level x509 and SSH
// policy configuration.
type Options struct {
X509 *X509PolicyOptions `json:"x509,omitempty"`
SSH *SSHPolicyOptions `json:"ssh,omitempty"`
}
// GetX509Options returns the x509 authority level policy
// configuration
func (o *Options) GetX509Options() *X509PolicyOptions {
if o == nil {
return nil
}
return o.X509
}
// GetSSHOptions returns the SSH authority level policy
// configuration
func (o *Options) GetSSHOptions() *SSHPolicyOptions {
if o == nil {
return nil
}
return o.SSH
}
// X509PolicyOptionsInterface is an interface for providers
// of x509 allowed and denied names.
type X509PolicyOptionsInterface interface {
GetAllowedNameOptions() *X509NameOptions
GetDeniedNameOptions() *X509NameOptions
AreWildcardNamesAllowed() bool
}
// X509PolicyOptions is a container for x509 allowed and denied
// names.
type X509PolicyOptions struct {
// AllowedNames contains the x509 allowed names
AllowedNames *X509NameOptions `json:"allow,omitempty"`
// DeniedNames contains the x509 denied names
DeniedNames *X509NameOptions `json:"deny,omitempty"`
// AllowWildcardNames indicates if literal wildcard names
// like *.example.com are allowed. Defaults to false.
AllowWildcardNames bool `json:"allowWildcardNames,omitempty"`
}
// X509NameOptions models the X509 name policy configuration.
type X509NameOptions struct {
CommonNames []string `json:"cn,omitempty"`
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
EmailAddresses []string `json:"email,omitempty"`
URIDomains []string `json:"uri,omitempty"`
}
// HasNames checks if the AllowedNameOptions has one or more
// names configured.
func (o *X509NameOptions) HasNames() bool {
return len(o.CommonNames) > 0 ||
len(o.DNSDomains) > 0 ||
len(o.IPRanges) > 0 ||
len(o.EmailAddresses) > 0 ||
len(o.URIDomains) > 0
}
// GetAllowedNameOptions returns x509 allowed name policy configuration
func (o *X509PolicyOptions) GetAllowedNameOptions() *X509NameOptions {
if o == nil {
return nil
}
return o.AllowedNames
}
// GetDeniedNameOptions returns the x509 denied name policy configuration
func (o *X509PolicyOptions) GetDeniedNameOptions() *X509NameOptions {
if o == nil {
return nil
}
return o.DeniedNames
}
// AreWildcardNamesAllowed returns whether the authority allows
// literal wildcard names to be signed.
func (o *X509PolicyOptions) AreWildcardNamesAllowed() bool {
if o == nil {
return true
}
return o.AllowWildcardNames
}
// SSHPolicyOptionsInterface is an interface for providers of
// SSH user and host name policy configuration.
type SSHPolicyOptionsInterface interface {
GetAllowedUserNameOptions() *SSHNameOptions
GetDeniedUserNameOptions() *SSHNameOptions
GetAllowedHostNameOptions() *SSHNameOptions
GetDeniedHostNameOptions() *SSHNameOptions
}
// SSHPolicyOptions is a container for SSH user and host policy
// configuration
type SSHPolicyOptions struct {
// User contains SSH user certificate options.
User *SSHUserCertificateOptions `json:"user,omitempty"`
// Host contains SSH host certificate options.
Host *SSHHostCertificateOptions `json:"host,omitempty"`
}
// GetAllowedUserNameOptions returns the SSH allowed user name policy
// configuration.
func (o *SSHPolicyOptions) GetAllowedUserNameOptions() *SSHNameOptions {
if o == nil || o.User == nil {
return nil
}
return o.User.AllowedNames
}
// GetDeniedUserNameOptions returns the SSH denied user name policy
// configuration.
func (o *SSHPolicyOptions) GetDeniedUserNameOptions() *SSHNameOptions {
if o == nil || o.User == nil {
return nil
}
return o.User.DeniedNames
}
// GetAllowedHostNameOptions returns the SSH allowed host name policy
// configuration.
func (o *SSHPolicyOptions) GetAllowedHostNameOptions() *SSHNameOptions {
if o == nil || o.Host == nil {
return nil
}
return o.Host.AllowedNames
}
// GetDeniedHostNameOptions returns the SSH denied host name policy
// configuration.
func (o *SSHPolicyOptions) GetDeniedHostNameOptions() *SSHNameOptions {
if o == nil || o.Host == nil {
return nil
}
return o.Host.DeniedNames
}
// SSHUserCertificateOptions is a collection of SSH user certificate options.
type SSHUserCertificateOptions struct {
// AllowedNames contains the names the provisioner is authorized to sign
AllowedNames *SSHNameOptions `json:"allow,omitempty"`
// DeniedNames contains the names the provisioner is not authorized to sign
DeniedNames *SSHNameOptions `json:"deny,omitempty"`
}
// SSHHostCertificateOptions is a collection of SSH host certificate options.
// It's an alias of SSHUserCertificateOptions, as the options are the same
// for both types of certificates.
type SSHHostCertificateOptions SSHUserCertificateOptions
// SSHNameOptions models the SSH name policy configuration.
type SSHNameOptions struct {
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
EmailAddresses []string `json:"email,omitempty"`
Principals []string `json:"principal,omitempty"`
}
// GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the
// names that a provisioner is authorized to sign SSH certificates for.
func (o *SSHUserCertificateOptions) GetAllowedNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
return o.AllowedNames
}
// GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the
// names that a provisioner is NOT authorized to sign SSH certificates for.
func (o *SSHUserCertificateOptions) GetDeniedNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
return o.DeniedNames
}
// HasNames checks if the SSHNameOptions has one or more
// names configured.
func (o *SSHNameOptions) HasNames() bool {
return len(o.DNSDomains) > 0 ||
len(o.IPRanges) > 0 ||
len(o.EmailAddresses) > 0 ||
len(o.Principals) > 0
}

@ -0,0 +1,45 @@
package policy
import (
"testing"
)
func TestX509PolicyOptions_IsWildcardLiteralAllowed(t *testing.T) {
tests := []struct {
name string
options *X509PolicyOptions
want bool
}{
{
name: "nil-options",
options: nil,
want: true,
},
{
name: "not-set",
options: &X509PolicyOptions{},
want: false,
},
{
name: "set-true",
options: &X509PolicyOptions{
AllowWildcardNames: true,
},
want: true,
},
{
name: "set-false",
options: &X509PolicyOptions{
AllowWildcardNames: false,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.options.AreWildcardNamesAllowed(); got != tt.want {
t.Errorf("X509PolicyOptions.IsWildcardLiteralAllowed() = %v, want %v", got, tt.want)
}
})
}
}

@ -0,0 +1,256 @@
package policy
import (
"fmt"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/policy"
)
// X509Policy is an alias for policy.X509NamePolicyEngine
type X509Policy policy.X509NamePolicyEngine
// UserPolicy is an alias for policy.SSHNamePolicyEngine
type UserPolicy policy.SSHNamePolicyEngine
// HostPolicy is an alias for policy.SSHNamePolicyEngine
type HostPolicy policy.SSHNamePolicyEngine
// NewX509PolicyEngine creates a new x509 name policy engine
func NewX509PolicyEngine(policyOptions X509PolicyOptionsInterface) (X509Policy, error) {
// return early if no policy engine options to configure
if policyOptions == nil {
return nil, nil
}
options := []policy.NamePolicyOption{}
allowed := policyOptions.GetAllowedNameOptions()
if allowed != nil && allowed.HasNames() {
options = append(options,
policy.WithPermittedCommonNames(allowed.CommonNames...),
policy.WithPermittedDNSDomains(allowed.DNSDomains...),
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...),
policy.WithPermittedEmailAddresses(allowed.EmailAddresses...),
policy.WithPermittedURIDomains(allowed.URIDomains...),
)
}
denied := policyOptions.GetDeniedNameOptions()
if denied != nil && denied.HasNames() {
options = append(options,
policy.WithExcludedCommonNames(denied.CommonNames...),
policy.WithExcludedDNSDomains(denied.DNSDomains...),
policy.WithExcludedIPsOrCIDRs(denied.IPRanges...),
policy.WithExcludedEmailAddresses(denied.EmailAddresses...),
policy.WithExcludedURIDomains(denied.URIDomains...),
)
}
// ensure no policy engine is returned when no name options were provided
if len(options) == 0 {
return nil, nil
}
// check if configuration specifies that wildcard names are allowed
if policyOptions.AreWildcardNamesAllowed() {
options = append(options, policy.WithAllowLiteralWildcardNames())
}
// enable subject common name verification by default
options = append(options, policy.WithSubjectCommonNameVerification())
return policy.New(options...)
}
type sshPolicyEngineType string
const (
UserPolicyEngineType sshPolicyEngineType = "user"
HostPolicyEngineType sshPolicyEngineType = "host"
)
// newSSHUserPolicyEngine creates a new SSH user certificate policy engine
func NewSSHUserPolicyEngine(policyOptions SSHPolicyOptionsInterface) (UserPolicy, error) {
policyEngine, err := newSSHPolicyEngine(policyOptions, UserPolicyEngineType)
if err != nil {
return nil, err
}
return policyEngine, nil
}
// newSSHHostPolicyEngine create a new SSH host certificate policy engine
func NewSSHHostPolicyEngine(policyOptions SSHPolicyOptionsInterface) (HostPolicy, error) {
policyEngine, err := newSSHPolicyEngine(policyOptions, HostPolicyEngineType)
if err != nil {
return nil, err
}
return policyEngine, nil
}
// newSSHPolicyEngine creates a new SSH name policy engine
func newSSHPolicyEngine(policyOptions SSHPolicyOptionsInterface, typ sshPolicyEngineType) (policy.SSHNamePolicyEngine, error) {
// return early if no policy engine options to configure
if policyOptions == nil {
return nil, nil
}
var (
allowed *SSHNameOptions
denied *SSHNameOptions
)
switch typ {
case UserPolicyEngineType:
allowed = policyOptions.GetAllowedUserNameOptions()
denied = policyOptions.GetDeniedUserNameOptions()
case HostPolicyEngineType:
allowed = policyOptions.GetAllowedHostNameOptions()
denied = policyOptions.GetDeniedHostNameOptions()
default:
return nil, fmt.Errorf("unknown SSH policy engine type %s provided", typ)
}
options := []policy.NamePolicyOption{}
if allowed != nil && allowed.HasNames() {
options = append(options,
policy.WithPermittedDNSDomains(allowed.DNSDomains...),
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...),
policy.WithPermittedEmailAddresses(allowed.EmailAddresses...),
policy.WithPermittedPrincipals(allowed.Principals...),
)
}
if denied != nil && denied.HasNames() {
options = append(options,
policy.WithExcludedDNSDomains(denied.DNSDomains...),
policy.WithExcludedIPsOrCIDRs(denied.IPRanges...),
policy.WithExcludedEmailAddresses(denied.EmailAddresses...),
policy.WithExcludedPrincipals(denied.Principals...),
)
}
// ensure no policy engine is returned when no name options were provided
if len(options) == 0 {
return nil, nil
}
return policy.New(options...)
}
func LinkedToCertificates(p *linkedca.Policy) *Options {
// return early
if p == nil {
return nil
}
// return early if x509 nor SSH is set
if p.GetX509() == nil && p.GetSsh() == nil {
return nil
}
opts := &Options{}
// fill x509 policy configuration
if x509 := p.GetX509(); x509 != nil {
opts.X509 = &X509PolicyOptions{}
if allow := x509.GetAllow(); allow != nil {
opts.X509.AllowedNames = &X509NameOptions{}
if allow.Dns != nil {
opts.X509.AllowedNames.DNSDomains = allow.Dns
}
if allow.Ips != nil {
opts.X509.AllowedNames.IPRanges = allow.Ips
}
if allow.Emails != nil {
opts.X509.AllowedNames.EmailAddresses = allow.Emails
}
if allow.Uris != nil {
opts.X509.AllowedNames.URIDomains = allow.Uris
}
if allow.CommonNames != nil {
opts.X509.AllowedNames.CommonNames = allow.CommonNames
}
}
if deny := x509.GetDeny(); deny != nil {
opts.X509.DeniedNames = &X509NameOptions{}
if deny.Dns != nil {
opts.X509.DeniedNames.DNSDomains = deny.Dns
}
if deny.Ips != nil {
opts.X509.DeniedNames.IPRanges = deny.Ips
}
if deny.Emails != nil {
opts.X509.DeniedNames.EmailAddresses = deny.Emails
}
if deny.Uris != nil {
opts.X509.DeniedNames.URIDomains = deny.Uris
}
if deny.CommonNames != nil {
opts.X509.DeniedNames.CommonNames = deny.CommonNames
}
}
opts.X509.AllowWildcardNames = x509.GetAllowWildcardNames()
}
// fill ssh policy configuration
if ssh := p.GetSsh(); ssh != nil {
opts.SSH = &SSHPolicyOptions{}
if host := ssh.GetHost(); host != nil {
opts.SSH.Host = &SSHHostCertificateOptions{}
if allow := host.GetAllow(); allow != nil {
opts.SSH.Host.AllowedNames = &SSHNameOptions{}
if allow.Dns != nil {
opts.SSH.Host.AllowedNames.DNSDomains = allow.Dns
}
if allow.Ips != nil {
opts.SSH.Host.AllowedNames.IPRanges = allow.Ips
}
if allow.Principals != nil {
opts.SSH.Host.AllowedNames.Principals = allow.Principals
}
}
if deny := host.GetDeny(); deny != nil {
opts.SSH.Host.DeniedNames = &SSHNameOptions{}
if deny.Dns != nil {
opts.SSH.Host.DeniedNames.DNSDomains = deny.Dns
}
if deny.Ips != nil {
opts.SSH.Host.DeniedNames.IPRanges = deny.Ips
}
if deny.Principals != nil {
opts.SSH.Host.DeniedNames.Principals = deny.Principals
}
}
}
if user := ssh.GetUser(); user != nil {
opts.SSH.User = &SSHUserCertificateOptions{}
if allow := user.GetAllow(); allow != nil {
opts.SSH.User.AllowedNames = &SSHNameOptions{}
if allow.Emails != nil {
opts.SSH.User.AllowedNames.EmailAddresses = allow.Emails
}
if allow.Principals != nil {
opts.SSH.User.AllowedNames.Principals = allow.Principals
}
}
if deny := user.GetDeny(); deny != nil {
opts.SSH.User.DeniedNames = &SSHNameOptions{}
if deny.Emails != nil {
opts.SSH.User.DeniedNames.EmailAddresses = deny.Emails
}
if deny.Principals != nil {
opts.SSH.User.DeniedNames.Principals = deny.Principals
}
}
}
}
return opts
}

@ -0,0 +1,155 @@
package policy
import (
"testing"
"github.com/google/go-cmp/cmp"
"go.step.sm/linkedca"
)
func TestPolicyToCertificates(t *testing.T) {
type args struct {
policy *linkedca.Policy
}
tests := []struct {
name string
args args
want *Options
}{
{
name: "nil",
args: args{
policy: nil,
},
want: nil,
},
{
name: "no-policy",
args: args{
&linkedca.Policy{},
},
want: nil,
},
{
name: "partial-policy",
args: args{
&linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
},
AllowWildcardNames: false,
},
},
},
want: &Options{
X509: &X509PolicyOptions{
AllowedNames: &X509NameOptions{
DNSDomains: []string{"*.local"},
},
AllowWildcardNames: false,
},
},
},
{
name: "full-policy",
args: args{
&linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"step"},
Ips: []string{"127.0.0.1/24"},
Emails: []string{"*.example.com"},
Uris: []string{"https://*.local"},
CommonNames: []string{"some name"},
},
Deny: &linkedca.X509Names{
Dns: []string{"bad"},
Ips: []string{"127.0.0.30"},
Emails: []string{"badhost.example.com"},
Uris: []string{"https://badhost.local"},
CommonNames: []string{"another name"},
},
AllowWildcardNames: true,
},
Ssh: &linkedca.SSHPolicy{
Host: &linkedca.SSHHostPolicy{
Allow: &linkedca.SSHHostNames{
Dns: []string{"*.localhost"},
Ips: []string{"127.0.0.1/24"},
Principals: []string{"user"},
},
Deny: &linkedca.SSHHostNames{
Dns: []string{"badhost.localhost"},
Ips: []string{"127.0.0.40"},
Principals: []string{"root"},
},
},
User: &linkedca.SSHUserPolicy{
Allow: &linkedca.SSHUserNames{
Emails: []string{"@work"},
Principals: []string{"user"},
},
Deny: &linkedca.SSHUserNames{
Emails: []string{"root@work"},
Principals: []string{"root"},
},
},
},
},
},
want: &Options{
X509: &X509PolicyOptions{
AllowedNames: &X509NameOptions{
DNSDomains: []string{"step"},
IPRanges: []string{"127.0.0.1/24"},
EmailAddresses: []string{"*.example.com"},
URIDomains: []string{"https://*.local"},
CommonNames: []string{"some name"},
},
DeniedNames: &X509NameOptions{
DNSDomains: []string{"bad"},
IPRanges: []string{"127.0.0.30"},
EmailAddresses: []string{"badhost.example.com"},
URIDomains: []string{"https://badhost.local"},
CommonNames: []string{"another name"},
},
AllowWildcardNames: true,
},
SSH: &SSHPolicyOptions{
Host: &SSHHostCertificateOptions{
AllowedNames: &SSHNameOptions{
DNSDomains: []string{"*.localhost"},
IPRanges: []string{"127.0.0.1/24"},
Principals: []string{"user"},
},
DeniedNames: &SSHNameOptions{
DNSDomains: []string{"badhost.localhost"},
IPRanges: []string{"127.0.0.40"},
Principals: []string{"root"},
},
},
User: &SSHUserCertificateOptions{
AllowedNames: &SSHNameOptions{
EmailAddresses: []string{"@work"},
Principals: []string{"user"},
},
DeniedNames: &SSHNameOptions{
EmailAddresses: []string{"root@work"},
Principals: []string{"root"},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := LinkedToCertificates(tt.args.policy)
if !cmp.Equal(tt.want, got) {
t.Errorf("policyToCertificates() diff=\n%s", cmp.Diff(tt.want, got))
}
})
}
}

File diff suppressed because it is too large Load Diff

@ -3,6 +3,8 @@ package provisioner
import (
"context"
"crypto/x509"
"fmt"
"net"
"time"
"github.com/pkg/errors"
@ -23,7 +25,8 @@ type ACME struct {
RequireEAB bool `json:"requireEAB,omitempty"`
Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"`
ctl *Controller
ctl *Controller
}
// GetID returns the provisioner unique identifier.
@ -71,7 +74,7 @@ func (p *ACME) DefaultTLSCertDuration() time.Duration {
return p.ctl.Claimer.DefaultTLSCertDuration()
}
// Init initializes and validates the fields of a JWK type.
// Init initializes and validates the fields of an ACME type.
func (p *ACME) Init(config Config) (err error) {
switch {
case p.Type == "":
@ -80,15 +83,56 @@ func (p *ACME) Init(config Config) (err error) {
return errors.New("provisioner name cannot be empty")
}
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
// ACMEIdentifierType encodes ACME Identifier types
type ACMEIdentifierType string
const (
// IP is the ACME ip identifier type
IP ACMEIdentifierType = "ip"
// DNS is the ACME dns identifier type
DNS ACMEIdentifierType = "dns"
)
// ACMEIdentifier encodes ACME Order Identifiers
type ACMEIdentifier struct {
Type ACMEIdentifierType
Value string
}
// AuthorizeOrderIdentifier verifies the provisioner is allowed to issue a
// certificate for an ACME Order Identifier.
func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier ACMEIdentifier) error {
x509Policy := p.ctl.getPolicy().getX509()
// identifier is allowed if no policy is configured
if x509Policy == nil {
return nil
}
// assuming only valid identifiers (IP or DNS) are provided
var err error
switch identifier.Type {
case IP:
err = x509Policy.IsIPAllowed(net.ParseIP(identifier.Value))
case DNS:
err = x509Policy.IsDNSAllowed(identifier.Value)
default:
err = fmt.Errorf("invalid ACME identifier type '%s' provided", identifier.Type)
}
return err
}
// AuthorizeSign does not do any validation, because all validation is handled
// in the ACME protocol. This method returns a list of modifiers / constraints
// on the resulting certificate.
func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{
opts := []SignOption{
p,
// modifiers / withOptions
newProvisionerExtensionOption(TypeACME, p.Name, ""),
@ -97,7 +141,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// validators
defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
}, nil
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
}
return opts, nil
}
// AuthorizeRevoke is called just before the certificate is to be revoked by

@ -176,7 +176,7 @@ func TestACME_AuthorizeSign(t *testing.T) {
}
} else {
if assert.Nil(t, tc.err) && assert.NotNil(t, opts) {
assert.Len(t, 6, opts)
assert.Equals(t, 7, len(opts)) // number of SignOptions returned
for _, o := range opts {
switch v := o.(type) {
case *ACME:
@ -193,6 +193,8 @@ func TestACME_AuthorizeSign(t *testing.T) {
case *validityValidator:
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}

@ -17,10 +17,12 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
)
// awsIssuer is the string used as issuer in the generated tokens.
@ -421,7 +423,7 @@ func (p *AWS) Init(config Config) (err error) {
}
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
@ -476,6 +478,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
defaultPublicKeyValidator{},
commonNameValidator(payload.Claims.Subject),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
), nil
}
@ -543,7 +546,7 @@ func (p *AWS) readURL(url string) ([]byte, error) {
if err != nil {
return nil, err
}
return nil, fmt.Errorf("Request for metadata returned non-successful status code %d",
return nil, fmt.Errorf("request for metadata returned non-successful status code %d",
resp.StatusCode)
}
@ -576,7 +579,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode)
return nil, fmt.Errorf("request for API token returned non-successful status code %d", resp.StatusCode)
}
token, err := io.ReadAll(resp.Body)
if err != nil {
@ -754,5 +757,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
), nil
}

@ -642,11 +642,11 @@ func TestAWS_AuthorizeSign(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{t1, "foo.local"}, 7, http.StatusOK, false},
{"ok", p2, args{t2, "instance-id"}, 11, http.StatusOK, false},
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 11, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 11, http.StatusOK, false},
{"ok", p1, args{t4, "instance-id"}, 7, http.StatusOK, false},
{"ok", p1, args{t1, "foo.local"}, 8, http.StatusOK, false},
{"ok", p2, args{t2, "instance-id"}, 12, http.StatusOK, false},
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 12, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 12, http.StatusOK, false},
{"ok", p1, args{t4, "instance-id"}, 8, http.StatusOK, false},
{"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true},
{"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true},
@ -673,7 +673,7 @@ func TestAWS_AuthorizeSign(t *testing.T) {
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
default:
assert.Len(t, tt.wantLen, got)
assert.Equals(t, tt.wantLen, len(got))
for _, o := range got {
switch v := o.(type) {
case *AWS:
@ -699,6 +699,8 @@ func TestAWS_AuthorizeSign(t *testing.T) {
assert.Equals(t, v, nil)
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"})
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}

@ -13,10 +13,12 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
)
// azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens.
@ -219,7 +221,7 @@ func (p *Azure) Init(config Config) (err error) {
return
}
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
@ -360,6 +362,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// validators
defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
), nil
}
@ -425,6 +428,8 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
), nil
}

@ -474,11 +474,11 @@ func TestAzure_AuthorizeSign(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{t1}, 6, http.StatusOK, false},
{"ok", p2, args{t2}, 11, http.StatusOK, false},
{"ok", p1, args{t11}, 6, http.StatusOK, false},
{"ok", p5, args{t5}, 6, http.StatusOK, false},
{"ok", p7, args{t7}, 6, http.StatusOK, false},
{"ok", p1, args{t1}, 7, http.StatusOK, false},
{"ok", p2, args{t2}, 12, http.StatusOK, false},
{"ok", p1, args{t11}, 7, http.StatusOK, false},
{"ok", p5, args{t5}, 7, http.StatusOK, false},
{"ok", p7, args{t7}, 7, http.StatusOK, false},
{"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true},
{"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true},
{"fail subscription", p6, args{t6}, 0, http.StatusUnauthorized, true},
@ -502,7 +502,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
default:
assert.Len(t, tt.wantLen, got)
assert.Equals(t, tt.wantLen, len(got))
for _, o := range got {
switch v := o.(type) {
case *Azure:
@ -528,6 +528,8 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.Equals(t, v, nil)
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"virtualMachine"})
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}

@ -21,14 +21,19 @@ type Controller struct {
IdentityFunc GetIdentityFunc
AuthorizeRenewFunc AuthorizeRenewFunc
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
policy *policyEngine
}
// NewController initializes a new provisioner controller.
func NewController(p Interface, claims *Claims, config Config) (*Controller, error) {
func NewController(p Interface, claims *Claims, config Config, options *Options) (*Controller, error) {
claimer, err := NewClaimer(claims, config.Claims)
if err != nil {
return nil, err
}
policy, err := newPolicyEngine(options)
if err != nil {
return nil, err
}
return &Controller{
Interface: p,
Audiences: &config.Audiences,
@ -36,6 +41,7 @@ func NewController(p Interface, claims *Claims, config Config) (*Controller, err
IdentityFunc: config.GetIdentityFunc,
AuthorizeRenewFunc: config.AuthorizeRenewFunc,
AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc,
policy: policy,
}, nil
}
@ -192,3 +198,10 @@ func SanitizeSSHUserPrincipal(email string) string {
}
}, strings.ToLower(email))
}
func (c *Controller) getPolicy() *policyEngine {
if c == nil {
return nil
}
return c.policy
}

@ -9,6 +9,8 @@ import (
"time"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/authority/policy"
)
var trueValue = true
@ -30,11 +32,40 @@ func mustDuration(t *testing.T, s string) *Duration {
return d
}
func mustNewPolicyEngine(t *testing.T, options *Options) *policyEngine {
t.Helper()
c, err := newPolicyEngine(options)
if err != nil {
t.Fatal(err)
}
return c
}
func TestNewController(t *testing.T) {
options := &Options{
X509: &X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"*.local"},
},
},
SSH: &SSHOptions{
Host: &policy.SSHHostCertificateOptions{
AllowedNames: &policy.SSHNameOptions{
DNSDomains: []string{"*.local"},
},
},
User: &policy.SSHUserCertificateOptions{
AllowedNames: &policy.SSHNameOptions{
EmailAddresses: []string{"@example.com"},
},
},
},
}
type args struct {
p Interface
claims *Claims
config Config
p Interface
claims *Claims
config Config
options *Options
}
tests := []struct {
name string
@ -45,7 +76,7 @@ func TestNewController(t *testing.T) {
{"ok", args{&JWK{}, nil, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, &Controller{
}, nil}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
@ -55,24 +86,49 @@ func TestNewController(t *testing.T) {
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, &Controller{
}, nil}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, globalProvisionerClaims),
}, false},
{"ok with claims and options", args{&JWK{}, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}, options}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, globalProvisionerClaims),
policy: mustNewPolicyEngine(t, options),
}, false},
{"fail claimer", args{&JWK{}, &Claims{
MinTLSDur: mustDuration(t, "24h"),
MaxTLSDur: mustDuration(t, "2h"),
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}, nil}, nil, true},
{"fail options", args{&JWK{}, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}, &Options{
X509: &X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"**.local"},
},
},
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config)
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config, tt.args.options)
if (err != nil) != tt.wantErr {
t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr)
return

@ -14,10 +14,12 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
)
// gcpCertsURL is the url that serves Google OAuth2 public keys.
@ -212,7 +214,7 @@ func (p *GCP) Init(config Config) (err error) {
}
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
@ -270,6 +272,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// validators
defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
), nil
}
@ -432,5 +435,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
), nil
}

@ -516,9 +516,9 @@ func TestGCP_AuthorizeSign(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{t1}, 6, http.StatusOK, false},
{"ok", p2, args{t2}, 11, http.StatusOK, false},
{"ok", p3, args{t3}, 6, http.StatusOK, false},
{"ok", p1, args{t1}, 7, http.StatusOK, false},
{"ok", p2, args{t2}, 12, http.StatusOK, false},
{"ok", p3, args{t3}, 7, http.StatusOK, false},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
{"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true},
@ -545,7 +545,7 @@ func TestGCP_AuthorizeSign(t *testing.T) {
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code)
default:
assert.Len(t, tt.wantLen, got)
assert.Equals(t, tt.wantLen, len(got))
for _, o := range got {
switch v := o.(type) {
case *GCP:
@ -571,6 +571,8 @@ func TestGCP_AuthorizeSign(t *testing.T) {
assert.Equals(t, v, nil)
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}

@ -7,10 +7,12 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
)
// jwtPayload extends jwt.Claims with step attributes.
@ -97,7 +99,7 @@ func (p *JWK) Init(config Config) (err error) {
return errors.New("provisioner key cannot be empty")
}
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
@ -141,6 +143,7 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
// revoke the certificate with serial number in the `sub` property.
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
}
@ -180,6 +183,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
defaultPublicKeyValidator{},
defaultSANsValidator(claims.SANs),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
}, nil
}
@ -188,6 +192,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals.
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRewew and AuthorizeSSHRenew)
return p.ctl.AuthorizeRenew(ctx, cert)
}
@ -260,11 +265,14 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
&sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()),
), nil
}
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke)
// TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
}

@ -297,7 +297,7 @@ func TestJWK_AuthorizeSign(t *testing.T) {
}
} else {
if assert.NotNil(t, got) {
assert.Len(t, 8, got)
assert.Equals(t, 9, len(got))
for _, o := range got {
switch v := o.(type) {
case *JWK:
@ -317,6 +317,8 @@ func TestJWK_AuthorizeSign(t *testing.T) {
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case defaultSANsValidator:
assert.Equals(t, []string(v), tt.sans)
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}

@ -10,11 +10,13 @@ import (
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
)
// NOTE: There can be at most one kubernetes service account provisioner configured
@ -91,6 +93,7 @@ func (p *K8sSA) GetEncryptedKey() (string, string, bool) {
// Init initializes and validates the fields of a K8sSA type.
func (p *K8sSA) Init(config Config) (err error) {
switch {
case p.Type == "":
return errors.New("provisioner type cannot be empty")
@ -137,7 +140,7 @@ func (p *K8sSA) Init(config Config) (err error) {
p.kauthn = k8s.AuthenticationV1()
*/
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
@ -239,6 +242,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// validators
defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
}, nil
}
@ -281,6 +285,8 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
&sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()),
), nil
}

@ -280,7 +280,6 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
tot := 0
for _, o := range opts {
switch v := o.(type) {
case *K8sSA:
@ -296,12 +295,13 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
case *validityValidator:
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
tot++
}
assert.Equals(t, tot, 6)
assert.Equals(t, 7, len(opts))
}
}
}
@ -368,7 +368,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
tot := 0
assert.Len(t, 7, opts)
for _, o := range opts {
switch v := o.(type) {
case sshCertificateOptionsFunc:
@ -380,12 +380,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
case *sshCertDefaultValidator:
case *sshDefaultDuration:
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshNamePolicyValidator:
assert.Equals(t, nil, v.userPolicyEngine)
assert.Equals(t, nil, v.hostPolicyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
tot++
}
assert.Equals(t, tot, 6)
}
}
}

@ -10,12 +10,14 @@ import (
"github.com/pkg/errors"
nebula "github.com/slackhq/nebula/cert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x25519"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/errs"
)
const (
@ -61,7 +63,7 @@ func (p *Nebula) Init(config Config) (err error) {
}
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
@ -161,6 +163,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
},
defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
}, nil
}
@ -256,6 +259,8 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
), nil
}

@ -12,10 +12,12 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
)
// openIDConfiguration contains the necessary properties in the
@ -195,7 +197,7 @@ func (o *OIDC) Init(config Config) (err error) {
return err
}
o.ctl, err = NewController(o, o.Claims, config)
o.ctl, err = NewController(o, o.Claims, config, o.Options)
return
}
@ -353,6 +355,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// validators
defaultPublicKeyValidator{},
newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(o.ctl.getPolicy().getX509()),
}, nil
}
@ -439,6 +442,8 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
&sshCertValidityValidator{o.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(o.ctl.getPolicy().getSSHHost(), o.ctl.getPolicy().getSSHUser()),
), nil
}

@ -323,7 +323,7 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got)
} else if assert.NotNil(t, got) {
assert.Len(t, 6, got)
assert.Equals(t, 7, len(got))
for _, o := range got {
switch v := o.(type) {
case *OIDC:
@ -341,6 +341,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case emailOnlyIdentity:
assert.Equals(t, string(v), "name@smallstep.com")
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}

@ -5,8 +5,11 @@ import (
"strings"
"github.com/pkg/errors"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/policy"
)
// CertificateOptions is an interface that returns a list of options passed when
@ -56,6 +59,16 @@ type X509Options struct {
// TemplateData is a JSON object with variables that can be used in custom
// templates.
TemplateData json.RawMessage `json:"templateData,omitempty"`
// AllowedNames contains the SANs the provisioner is authorized to sign
AllowedNames *policy.X509NameOptions `json:"-"`
// DeniedNames contains the SANs the provisioner is not authorized to sign
DeniedNames *policy.X509NameOptions `json:"-"`
// AllowWildcardNames indicates if literal wildcard names
// like *.example.com are allowed. Defaults to false.
AllowWildcardNames bool `json:"-"`
}
// HasTemplate returns true if a template is defined in the provisioner options.
@ -63,6 +76,31 @@ func (o *X509Options) HasTemplate() bool {
return o != nil && (o.Template != "" || o.TemplateFile != "")
}
// GetAllowedNameOptions returns the AllowedNames, which models the
// SANs that a provisioner is authorized to sign x509 certificates for.
func (o *X509Options) GetAllowedNameOptions() *policy.X509NameOptions {
if o == nil {
return nil
}
return o.AllowedNames
}
// GetDeniedNameOptions returns the DeniedNames, which models the
// SANs that a provisioner is NOT authorized to sign x509 certificates for.
func (o *X509Options) GetDeniedNameOptions() *policy.X509NameOptions {
if o == nil {
return nil
}
return o.DeniedNames
}
func (o *X509Options) AreWildcardNamesAllowed() bool {
if o == nil {
return true
}
return o.AllowWildcardNames
}
// TemplateOptions generates a CertificateOptions with the template and data
// defined in the ProvisionerOptions, the provisioner generated data, and the
// user data provided in the request. If no template has been provided,

@ -287,3 +287,38 @@ func Test_unsafeParseSigned(t *testing.T) {
})
}
}
func TestX509Options_IsWildcardLiteralAllowed(t *testing.T) {
tests := []struct {
name string
options *X509Options
want bool
}{
{
name: "nil-options",
options: nil,
want: true,
},
{
name: "set-true",
options: &X509Options{
AllowWildcardNames: true,
},
want: true,
},
{
name: "set-false",
options: &X509Options{
AllowWildcardNames: false,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.options.AreWildcardNamesAllowed(); got != tt.want {
t.Errorf("X509PolicyOptions.IsWildcardLiteralAllowed() = %v, want %v", got, tt.want)
}
})
}
}

@ -0,0 +1,65 @@
package provisioner
import "github.com/smallstep/certificates/authority/policy"
type policyEngine struct {
x509Policy policy.X509Policy
sshHostPolicy policy.HostPolicy
sshUserPolicy policy.UserPolicy
}
func newPolicyEngine(options *Options) (*policyEngine, error) {
if options == nil {
return nil, nil
}
var (
x509Policy policy.X509Policy
sshHostPolicy policy.HostPolicy
sshUserPolicy policy.UserPolicy
err error
)
// Initialize the x509 allow/deny policy engine
if x509Policy, err = policy.NewX509PolicyEngine(options.GetX509Options()); err != nil {
return nil, err
}
// Initialize the SSH allow/deny policy engine for host certificates
if sshHostPolicy, err = policy.NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil {
return nil, err
}
// Initialize the SSH allow/deny policy engine for user certificates
if sshUserPolicy, err = policy.NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil {
return nil, err
}
return &policyEngine{
x509Policy: x509Policy,
sshHostPolicy: sshHostPolicy,
sshUserPolicy: sshUserPolicy,
}, nil
}
func (p *policyEngine) getX509() policy.X509Policy {
if p == nil {
return nil
}
return p.x509Policy
}
func (p *policyEngine) getSSHHost() policy.HostPolicy {
if p == nil {
return nil
}
return p.sshHostPolicy
}
func (p *policyEngine) getSSHUser() policy.UserPolicy {
if p == nil {
return nil
}
return p.sshUserPolicy
}

@ -28,13 +28,12 @@ type SCEP struct {
// Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7
// at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63
// Defaults to 0, being DES-CBC
EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"`
Options *Options `json:"options,omitempty"`
Claims *Claims `json:"claims,omitempty"`
secretChallengePassword string
encryptionAlgorithm int
ctl *Controller
EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"`
Options *Options `json:"options,omitempty"`
Claims *Claims `json:"claims,omitempty"`
ctl *Controller
secretChallengePassword string
encryptionAlgorithm int
}
// GetID returns the provisioner unique identifier.
@ -84,7 +83,6 @@ func (s *SCEP) DefaultTLSCertDuration() time.Duration {
// Init initializes and validates the fields of a SCEP type.
func (s *SCEP) Init(config Config) (err error) {
switch {
case s.Type == "":
return errors.New("provisioner type cannot be empty")
@ -112,7 +110,7 @@ func (s *SCEP) Init(config Config) (err error) {
// TODO: add other, SCEP specific, options?
s.ctl, err = NewController(s, s.Claims, config)
s.ctl, err = NewController(s, s.Claims, config, s.Options)
return
}
@ -129,6 +127,7 @@ func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// validators
newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength),
newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(s.ctl.getPolicy().getX509()),
}, nil
}

@ -13,9 +13,11 @@ import (
"reflect"
"time"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
)
// DefaultCertValidity is the default validity for a certificate if none is specified.
@ -402,6 +404,32 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
return nil
}
// x509NamePolicyValidator validates that the certificate (to be signed)
// contains only allowed SANs.
type x509NamePolicyValidator struct {
policyEngine policy.X509Policy
}
// newX509NamePolicyValidator return a new SANs allow/deny validator.
func newX509NamePolicyValidator(engine policy.X509Policy) *x509NamePolicyValidator {
return &x509NamePolicyValidator{
policyEngine: engine,
}
}
// Valid validates that the certificate (to be signed) contains only allowed SANs.
func (v *x509NamePolicyValidator) Valid(cert *x509.Certificate, _ SignOptions) error {
if v.policyEngine == nil {
return nil
}
return v.policyEngine.IsX509CertificateAllowed(cert)
}
// var (
// stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
// stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
// )
// type stepProvisionerASN1 struct {
// Type int
// Name []byte

@ -4,11 +4,13 @@ import (
"crypto/rsa"
"encoding/binary"
"encoding/json"
"fmt"
"math/big"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/keyutil"
"golang.org/x/crypto/ssh"
@ -444,6 +446,53 @@ func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOpti
}
}
// sshNamePolicyValidator validates that the certificate (to be signed)
// contains only allowed principals.
type sshNamePolicyValidator struct {
hostPolicyEngine policy.HostPolicy
userPolicyEngine policy.UserPolicy
}
// newSSHNamePolicyValidator return a new SSH allow/deny validator.
func newSSHNamePolicyValidator(host policy.HostPolicy, user policy.UserPolicy) *sshNamePolicyValidator {
return &sshNamePolicyValidator{
hostPolicyEngine: host,
userPolicyEngine: user,
}
}
// Valid validates that the certificate (to be signed) contains only allowed principals.
func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) error {
if v.hostPolicyEngine == nil && v.userPolicyEngine == nil {
// no policy configured at all; allow anything
return nil
}
// Check the policy type to execute based on type of the certificate.
// We don't allow user certs if only a host policy engine is configured and
// the same for host certs: if only a user policy engine is configured, host
// certs are denied. When both policy engines are configured, the type of
// cert determines which policy engine is used.
switch cert.CertType {
case ssh.HostCert:
// when no host policy engine is configured, but a user policy engine is
// configured, the host certificate is denied.
if v.hostPolicyEngine == nil && v.userPolicyEngine != nil {
return errors.New("SSH host certificate not authorized")
}
return v.hostPolicyEngine.IsSSHCertificateAllowed(cert)
case ssh.UserCert:
// when no user policy engine is configured, but a host policy engine is
// configured, the user certificate is denied.
if v.userPolicyEngine == nil && v.hostPolicyEngine != nil {
return errors.New("SSH user certificate not authorized")
}
return v.userPolicyEngine.IsSSHCertificateAllowed(cert)
default:
return fmt.Errorf("unexpected SSH certificate type %d", cert.CertType) // satisfy return; shouldn't happen
}
}
// sshCertTypeUInt32
func sshCertTypeUInt32(ct string) uint32 {
switch ct {

@ -6,6 +6,8 @@ import (
"github.com/pkg/errors"
"go.step.sm/crypto/sshutil"
"github.com/smallstep/certificates/authority/policy"
)
// SSHCertificateOptions is an interface that returns a list of options passed when
@ -33,6 +35,60 @@ type SSHOptions struct {
// TemplateData is a JSON object with variables that can be used in custom
// templates.
TemplateData json.RawMessage `json:"templateData,omitempty"`
// User contains SSH user certificate options.
User *policy.SSHUserCertificateOptions `json:"-"`
// Host contains SSH host certificate options.
Host *policy.SSHHostCertificateOptions `json:"-"`
}
// GetAllowedUserNameOptions returns the SSHNameOptions that are
// allowed when SSH User certificates are requested.
func (o *SSHOptions) GetAllowedUserNameOptions() *policy.SSHNameOptions {
if o == nil {
return nil
}
if o.User == nil {
return nil
}
return o.User.AllowedNames
}
// GetDeniedUserNameOptions returns the SSHNameOptions that are
// denied when SSH user certificates are requested.
func (o *SSHOptions) GetDeniedUserNameOptions() *policy.SSHNameOptions {
if o == nil {
return nil
}
if o.User == nil {
return nil
}
return o.User.DeniedNames
}
// GetAllowedHostNameOptions returns the SSHNameOptions that are
// allowed when SSH host certificates are requested.
func (o *SSHOptions) GetAllowedHostNameOptions() *policy.SSHNameOptions {
if o == nil {
return nil
}
if o.Host == nil {
return nil
}
return o.Host.AllowedNames
}
// GetDeniedHostNameOptions returns the SSHNameOptions that are
// denied when SSH host certificates are requested.
func (o *SSHOptions) GetDeniedHostNameOptions() *policy.SSHNameOptions {
if o == nil {
return nil
}
if o.Host == nil {
return nil
}
return o.Host.DeniedNames
}
// HasTemplate returns true if a template is defined in the provisioner options.

@ -8,9 +8,11 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/jose"
"github.com/smallstep/certificates/errs"
)
// sshPOPPayload extends jwt.Claims with step attributes.
@ -95,7 +97,7 @@ func (p *SSHPOP) Init(config Config) (err error) {
p.sshPubKeys = config.SSHKeys
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, nil)
return
}

@ -184,7 +184,7 @@ func generateJWK() (*JWK, error) {
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
}, nil)
return p, err
}
@ -219,7 +219,7 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
}, nil)
return p, err
}
@ -256,7 +256,7 @@ func generateSSHPOP() (*SSHPOP, error) {
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
}, nil)
return p, err
}
@ -305,7 +305,7 @@ M46l92gdOozT
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
}, nil)
return p, err
}
@ -343,7 +343,7 @@ func generateOIDC() (*OIDC, error) {
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
}, nil)
return p, err
}
@ -373,7 +373,7 @@ func generateGCP() (*GCP, error) {
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("gcp/" + name),
})
}, nil)
return p, err
}
@ -411,7 +411,7 @@ func generateAWS() (*AWS, error) {
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("aws/" + name),
})
}, nil)
return p, err
}
@ -518,7 +518,7 @@ func generateAWSV1Only() (*AWS, error) {
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("aws/" + name),
})
}, nil)
return p, err
}
@ -608,7 +608,7 @@ func generateAzure() (*Azure, error) {
}
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
}, nil)
return p, err
}

@ -8,10 +8,12 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
)
// x5cPayload extends jwt.Claims with step attributes.
@ -121,7 +123,7 @@ func (p *X5C) Init(config Config) (err error) {
}
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
p.ctl, err = NewController(p, p.Claims, config, p.Options)
return
}
@ -233,6 +235,7 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
defaultSANsValidator(claims.SANs),
defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
}, nil
}
@ -317,5 +320,7 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
&sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate
&sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()),
), nil
}

@ -468,7 +468,7 @@ func TestX5C_AuthorizeSign(t *testing.T) {
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
assert.Equals(t, len(opts), 8)
assert.Equals(t, 9, len(opts))
for _, o := range opts {
switch v := o.(type) {
case *X5C:
@ -480,7 +480,6 @@ func TestX5C_AuthorizeSign(t *testing.T) {
assert.Len(t, 0, v.KeyValuePairs)
case profileLimitDuration:
assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration())
claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign)
assert.FatalError(t, err)
assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter)
@ -492,6 +491,8 @@ func TestX5C_AuthorizeSign(t *testing.T) {
case *validityValidator:
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
@ -788,6 +789,9 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshNamePolicyValidator:
assert.Equals(t, nil, v.userPolicyEngine)
assert.Equals(t, nil, v.hostPolicyEngine)
case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc:
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
@ -795,9 +799,9 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
tot++
}
if len(tc.claims.Step.SSH.CertType) > 0 {
assert.Equals(t, tot, 9)
assert.Equals(t, tot, 10)
} else {
assert.Equals(t, tot, 7)
assert.Equals(t, tot, 8)
}
}
}

@ -10,16 +10,19 @@ import (
"os"
"github.com/pkg/errors"
"gopkg.in/square/go-jose.v2/jwt"
"go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose"
"go.step.sm/linkedca"
"gopkg.in/square/go-jose.v2/jwt"
)
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
@ -170,6 +173,12 @@ func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisi
return admin.WrapErrorISE(err, "error generating provisioner config")
}
adm := linkedca.MustAdminFromContext(ctx)
if err := a.checkProvisionerPolicy(ctx, adm, prov.Name, prov.Policy); err != nil {
return err
}
if err := certProv.Init(provisionerConfig); err != nil {
return admin.WrapError(admin.ErrorBadRequestType, err, "error validating configuration for provisioner %s", prov.Name)
}
@ -215,6 +224,12 @@ func (a *Authority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisio
return admin.WrapErrorISE(err, "error generating provisioner config")
}
adm := linkedca.MustAdminFromContext(ctx)
if err := a.checkProvisionerPolicy(ctx, adm, nu.Name, nu.Policy); err != nil {
return err
}
if err := certProv.Init(provisionerConfig); err != nil {
return admin.WrapErrorISE(err, "error initializing provisioner %s", nu.Name)
}
@ -427,6 +442,60 @@ func optionsToCertificates(p *linkedca.Provisioner) *provisioner.Options {
ops.SSH.Template = string(p.SshTemplate.Template)
ops.SSH.TemplateData = p.SshTemplate.Data
}
if pol := p.GetPolicy(); pol != nil {
if x := pol.GetX509(); x != nil {
if allow := x.GetAllow(); allow != nil {
ops.X509.AllowedNames = &policy.X509NameOptions{
DNSDomains: allow.Dns,
IPRanges: allow.Ips,
EmailAddresses: allow.Emails,
URIDomains: allow.Uris,
}
}
if deny := x.GetDeny(); deny != nil {
ops.X509.DeniedNames = &policy.X509NameOptions{
DNSDomains: deny.Dns,
IPRanges: deny.Ips,
EmailAddresses: deny.Emails,
URIDomains: deny.Uris,
}
}
}
if ssh := pol.GetSsh(); ssh != nil {
if host := ssh.GetHost(); host != nil {
ops.SSH.Host = &policy.SSHHostCertificateOptions{}
if allow := host.GetAllow(); allow != nil {
ops.SSH.Host.AllowedNames = &policy.SSHNameOptions{
DNSDomains: allow.Dns,
IPRanges: allow.Ips,
Principals: allow.Principals,
}
}
if deny := host.GetDeny(); deny != nil {
ops.SSH.Host.DeniedNames = &policy.SSHNameOptions{
DNSDomains: deny.Dns,
IPRanges: deny.Ips,
Principals: deny.Principals,
}
}
}
if user := ssh.GetUser(); user != nil {
ops.SSH.User = &policy.SSHUserCertificateOptions{}
if allow := user.GetAllow(); allow != nil {
ops.SSH.User.AllowedNames = &policy.SSHNameOptions{
EmailAddresses: allow.Emails,
Principals: allow.Principals,
}
}
if deny := user.GetDeny(); deny != nil {
ops.SSH.User.DeniedNames = &policy.SSHNameOptions{
EmailAddresses: deny.Emails,
Principals: deny.Principals,
}
}
}
}
}
return ops
}

@ -5,18 +5,23 @@ import (
"crypto/rand"
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
"net/http"
"strings"
"time"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/randutil"
"go.step.sm/crypto/sshutil"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
policy "github.com/smallstep/certificates/policy"
"github.com/smallstep/certificates/templates"
"go.step.sm/crypto/randutil"
"go.step.sm/crypto/sshutil"
"golang.org/x/crypto/ssh"
)
const (
@ -241,6 +246,23 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", certTpl.CertType)
}
// Check if authority is allowed to sign the certificate
if err := a.isAllowedToSignSSHCertificate(certTpl); err != nil {
var pe *policy.NamePolicyError
if errors.As(err, &pe) && pe.Reason == policy.NotAllowed {
return nil, &errs.Error{
// NOTE: custom forbidden error, so that denied name is sent to client
// as well as shown in the logs.
Status: http.StatusForbidden,
Err: fmt.Errorf("authority not allowed to sign: %w", err),
Msg: fmt.Sprintf("The request was forbidden by the certificate authority: %s", err.Error()),
}
}
return nil, errs.InternalServerErr(err,
errs.WithMessage("authority.SignSSH: error creating ssh certificate"),
)
}
// Sign certificate.
cert, err := sshutil.CreateCertificate(certTpl, signer)
if err != nil {
@ -261,6 +283,11 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
return cert, nil
}
// isAllowedToSignSSHCertificate checks if the Authority is allowed to sign the SSH certificate.
func (a *Authority) isAllowedToSignSSHCertificate(cert *ssh.Certificate) error {
return a.policyEngine.IsSSHCertificateAllowed(cert)
}
// RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) {
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {

@ -20,6 +20,7 @@ import (
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/templates"
@ -159,6 +160,14 @@ func TestAuthority_SignSSH(t *testing.T) {
assert.FatalError(t, err)
hostTemplateWithHosts, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.HostCert, "key-id", []string{"foo.test.com", "bar.test.com"}))
assert.FatalError(t, err)
userTemplateWithRoot, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"root"}))
assert.FatalError(t, err)
hostTemplateWithExampleDotCom, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.HostCert, "key-id", []string{"example.com"}))
assert.FatalError(t, err)
badUserTemplate, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"127.0.0.1"}))
assert.FatalError(t, err)
badHostTemplate, err := provisioner.TemplateSSHOptions(nil, sshutil.CreateTemplateData(sshutil.HostCert, "key-id", []string{"host...local"}))
assert.FatalError(t, err)
userCustomTemplate, err := provisioner.TemplateSSHOptions(&provisioner.Options{
SSH: &provisioner.SSHOptions{Template: `{
"type": "{{ .Type }}",
@ -182,11 +191,36 @@ func TestAuthority_SignSSH(t *testing.T) {
}, sshutil.CreateTemplateData(sshutil.UserCert, "key-id", []string{"user"}))
assert.FatalError(t, err)
userPolicyOptions := &policy.Options{
SSH: &policy.SSHPolicyOptions{
User: &policy.SSHUserCertificateOptions{
AllowedNames: &policy.SSHNameOptions{
Principals: []string{"user"},
},
},
},
}
userPolicy, err := policy.New(userPolicyOptions)
assert.FatalError(t, err)
hostPolicyOptions := &policy.Options{
SSH: &policy.SSHPolicyOptions{
Host: &policy.SSHHostCertificateOptions{
AllowedNames: &policy.SSHNameOptions{
DNSDomains: []string{"*.test.com"},
},
},
},
}
hostPolicy, err := policy.New(hostPolicyOptions)
assert.FatalError(t, err)
now := time.Now()
type fields struct {
sshCAUserCertSignKey ssh.Signer
sshCAHostCertSignKey ssh.Signer
policyEngine *policy.Engine
}
type args struct {
key ssh.PublicKey
@ -206,39 +240,48 @@ func TestAuthority_SignSSH(t *testing.T) {
want want
wantErr bool
}{
{"ok-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false},
{"ok-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false},
{"ok-user-only", fields{signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false},
{"ok-host-only", fields{nil, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false},
{"ok-opts-type-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert}, false},
{"ok-opts-type-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert}, false},
{"ok-opts-principals", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false},
{"ok-opts-principals", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{hostTemplateWithHosts}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false},
{"ok-opts-valid-after", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user", ValidAfter: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert, ValidAfter: uint64(now.Unix())}, false},
{"ok-opts-valid-before", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host", ValidBefore: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert, ValidBefore: uint64(now.Unix())}, false},
{"ok-cert-validator", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("")}}, want{CertType: ssh.UserCert}, false},
{"ok-cert-modifier", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("")}}, want{CertType: ssh.UserCert}, false},
{"ok-opts-validator", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("")}}, want{CertType: ssh.UserCert}, false},
{"ok-opts-modifier", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("")}}, want{CertType: ssh.UserCert}, false},
{"ok-custom-template", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userCustomTemplate, userOptions}}, want{CertType: ssh.UserCert, Principals: []string{"user", "admin"}}, false},
{"fail-opts-type", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "foo"}, []provisioner.SignOption{userTemplate}}, want{}, true},
{"fail-cert-validator", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("an error")}}, want{}, true},
{"fail-cert-modifier", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("an error")}}, want{}, true},
{"fail-opts-validator", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("an error")}}, want{}, true},
{"fail-opts-modifier", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("an error")}}, want{}, true},
{"fail-bad-sign-options", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, "wrong type"}}, want{}, true},
{"fail-no-user-key", fields{nil, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{}, true},
{"fail-no-host-key", fields{signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{}, true},
{"fail-bad-type", fields{signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, sshTestModifier{CertType: 100}}}, want{}, true},
{"fail-custom-template", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userFailTemplate, userOptions}}, want{}, true},
{"fail-custom-template-syntax-error-file", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONSyntaxErrorTemplateFile, userOptions}}, want{}, true},
{"fail-custom-template-syntax-value-file", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONValueErrorTemplateFile, userOptions}}, want{}, true},
{"ok-user", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false},
{"ok-host", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false},
{"ok-user-only", fields{signer, nil, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false},
{"ok-host-only", fields{nil, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false},
{"ok-opts-type-user", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert}, false},
{"ok-opts-type-host", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert}, false},
{"ok-opts-principals", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false},
{"ok-opts-principals", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{hostTemplateWithHosts}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false},
{"ok-opts-valid-after", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user", ValidAfter: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert, ValidAfter: uint64(now.Unix())}, false},
{"ok-opts-valid-before", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host", ValidBefore: provisioner.NewTimeDuration(now)}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert, ValidBefore: uint64(now.Unix())}, false},
{"ok-cert-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("")}}, want{CertType: ssh.UserCert}, false},
{"ok-cert-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("")}}, want{CertType: ssh.UserCert}, false},
{"ok-opts-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("")}}, want{CertType: ssh.UserCert}, false},
{"ok-opts-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("")}}, want{CertType: ssh.UserCert}, false},
{"ok-custom-template", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userCustomTemplate, userOptions}}, want{CertType: ssh.UserCert, Principals: []string{"user", "admin"}}, false},
{"ok-user-policy", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false},
{"ok-host-policy", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com", "bar.test.com"}}, []provisioner.SignOption{hostTemplateWithHosts}}, want{CertType: ssh.HostCert, Principals: []string{"foo.test.com", "bar.test.com"}}, false},
{"fail-opts-type", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "foo"}, []provisioner.SignOption{userTemplate}}, want{}, true},
{"fail-cert-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertValidator("an error")}}, want{}, true},
{"fail-cert-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestCertModifier("an error")}}, want{}, true},
{"fail-opts-validator", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsValidator("an error")}}, want{}, true},
{"fail-opts-modifier", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, sshTestOptionsModifier("an error")}}, want{}, true},
{"fail-bad-sign-options", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions, "wrong type"}}, want{}, true},
{"fail-no-user-key", fields{nil, signer, nil}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{}, true},
{"fail-no-host-key", fields{signer, nil, nil}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{}, true},
{"fail-bad-type", fields{signer, nil, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, sshTestModifier{CertType: 100}}}, want{}, true},
{"fail-custom-template", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userFailTemplate, userOptions}}, want{}, true},
{"fail-custom-template-syntax-error-file", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONSyntaxErrorTemplateFile, userOptions}}, want{}, true},
{"fail-custom-template-syntax-value-file", fields{signer, signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userJSONValueErrorTemplateFile, userOptions}}, want{}, true},
{"fail-user-policy", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"root"}}, []provisioner.SignOption{userTemplateWithRoot}}, want{}, true},
{"fail-user-policy-with-host-cert", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"foo.test.com"}}, []provisioner.SignOption{hostTemplateWithExampleDotCom}}, want{}, true},
{"fail-user-policy-with-bad-user", fields{signer, signer, userPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{badUserTemplate}}, want{}, true},
{"fail-host-policy", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"example.com"}}, []provisioner.SignOption{hostTemplateWithExampleDotCom}}, want{}, true},
{"fail-host-policy-with-user-cert", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{}, true},
{"fail-host-policy-with-bad-host", fields{signer, signer, hostPolicy}, args{pub, provisioner.SignSSHOptions{CertType: "host", Principals: []string{"example.com"}}, []provisioner.SignOption{badHostTemplate}}, want{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := testAuthority(t)
a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey
a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey
a.policyEngine = tt.fields.policyEngine
got, err := a.SignSSH(context.Background(), tt.args.key, tt.args.opts, tt.args.signOpts...)
if (err != nil) != tt.wantErr {

@ -16,16 +16,19 @@ import (
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
casapi "github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/policy"
)
// GetTLSOptions returns the tls options configured.
@ -196,6 +199,25 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
}
}
// Check if authority is allowed to sign the certificate
if err := a.isAllowedToSignX509Certificate(leaf); err != nil {
var pe *policy.NamePolicyError
if errors.As(err, &pe) && pe.Reason == policy.NotAllowed {
return nil, errs.ApplyOptions(&errs.Error{
// NOTE: custom forbidden error, so that denied name is sent to client
// as well as shown in the logs.
Status: http.StatusForbidden,
Err: fmt.Errorf("authority not allowed to sign: %w", err),
Msg: fmt.Sprintf("The request was forbidden by the certificate authority: %s", err.Error()),
}, opts...)
}
return nil, errs.InternalServerErr(err,
errs.WithKeyVal("csr", csr),
errs.WithKeyVal("signOptions", signOpts),
errs.WithMessage("error creating certificate"),
)
}
// Sign certificate
lifetime := leaf.NotAfter.Sub(leaf.NotBefore.Add(signOpts.Backdate))
resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{
@ -219,6 +241,18 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
return fullchain, nil
}
// isAllowedToSignX509Certificate checks if the Authority is allowed
// to sign the X.509 certificate.
func (a *Authority) isAllowedToSignX509Certificate(cert *x509.Certificate) error {
return a.policyEngine.IsX509CertificateAllowed(cert)
}
// AreSANsAllowed evaluates the provided sans against the
// authority X.509 policy.
func (a *Authority) AreSANsAllowed(ctx context.Context, sans []string) error {
return a.policyEngine.AreSANsAllowed(sans)
}
// Renew creates a new Certificate identical to the old certificate, except
// with a validity window that begins 'now'.
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {

@ -27,6 +27,7 @@ import (
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/cas/softcas"
"github.com/smallstep/certificates/db"
@ -511,6 +512,39 @@ ZYtQ9Ot36qc=
code: http.StatusForbidden,
}
},
"fail with policy": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
aa := testAuthority(t)
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
aa.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error {
fmt.Println(crt.Subject)
assert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil
},
}
options := &policy.Options{
X509: &policy.X509PolicyOptions{
DeniedNames: &policy.X509NameOptions{
DNSDomains: []string{"test.smallstep.com"},
},
},
}
engine, err := policy.New(options)
assert.FatalError(t, err)
aa.policyEngine = engine
return &signTest{
auth: aa,
csr: csr,
extraOpts: extraOpts,
signOpts: signOpts,
notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
extensionsCount: 6,
err: errors.New("authority not allowed to sign"),
code: http.StatusForbidden,
}
},
"ok": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
_a := testAuthority(t)
@ -653,6 +687,38 @@ ZYtQ9Ot36qc=
extensionsCount: 7,
}
},
"ok with policy": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
aa := testAuthority(t)
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
aa.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error {
fmt.Println(crt.Subject)
assert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil
},
}
options := &policy.Options{
X509: &policy.X509PolicyOptions{
AllowedNames: &policy.X509NameOptions{
CommonNames: []string{"smallstep test"},
DNSDomains: []string{"*.smallstep.com"},
},
},
}
engine, err := policy.New(options)
assert.FatalError(t, err)
aa.policyEngine = engine
return &signTest{
auth: aa,
csr: csr,
extraOpts: extraOpts,
signOpts: signOpts,
notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
extensionsCount: 6,
}
},
}
for name, genTestCase := range tests {

@ -4,6 +4,7 @@ import (
"bytes"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
@ -12,15 +13,17 @@ import (
"time"
"github.com/pkg/errors"
adminAPI "github.com/smallstep/certificates/authority/admin/api"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"google.golang.org/protobuf/encoding/protojson"
"go.step.sm/cli-utils/token"
"go.step.sm/cli-utils/token/provision"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/randutil"
"go.step.sm/linkedca"
"google.golang.org/protobuf/encoding/protojson"
adminAPI "github.com/smallstep/certificates/authority/admin/api"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
)
const (
@ -687,6 +690,418 @@ retry:
return nil
}
func (c *AdminClient) GetAuthorityPolicy() (*linkedca.Policy, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")})
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, fmt.Errorf("error generating admin token: %w", err)
}
req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody)
if err != nil {
return nil, fmt.Errorf("creating GET %s request failed: %w", u, err)
}
req.Header.Add("Authorization", tok)
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("client GET %s failed: %w", u, err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readAdminError(resp.Body)
}
var policy = new(linkedca.Policy)
if err := readProtoJSON(resp.Body, policy); err != nil {
return nil, fmt.Errorf("error reading %s: %w", u, err)
}
return policy, nil
}
func (c *AdminClient) CreateAuthorityPolicy(p *linkedca.Policy) (*linkedca.Policy, error) {
var retried bool
body, err := protojson.Marshal(p)
if err != nil {
return nil, fmt.Errorf("error marshaling request: %w", err)
}
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")})
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, fmt.Errorf("error generating admin token: %w", err)
}
req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("creating POST %s request failed: %w", u, err)
}
req.Header.Add("Authorization", tok)
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("client POST %s failed: %w", u, err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readAdminError(resp.Body)
}
var policy = new(linkedca.Policy)
if err := readProtoJSON(resp.Body, policy); err != nil {
return nil, fmt.Errorf("error reading %s: %w", u, err)
}
return policy, nil
}
func (c *AdminClient) UpdateAuthorityPolicy(p *linkedca.Policy) (*linkedca.Policy, error) {
var retried bool
body, err := protojson.Marshal(p)
if err != nil {
return nil, fmt.Errorf("error marshaling request: %w", err)
}
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")})
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, fmt.Errorf("error generating admin token: %w", err)
}
req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("creating PUT %s request failed: %w", u, err)
}
req.Header.Add("Authorization", tok)
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("client PUT %s failed: %w", u, err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readAdminError(resp.Body)
}
var policy = new(linkedca.Policy)
if err := readProtoJSON(resp.Body, policy); err != nil {
return nil, fmt.Errorf("error reading %s: %w", u, err)
}
return policy, nil
}
func (c *AdminClient) RemoveAuthorityPolicy() error {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "policy")})
tok, err := c.generateAdminToken(u)
if err != nil {
return fmt.Errorf("error generating admin token: %w", err)
}
req, err := http.NewRequest(http.MethodDelete, u.String(), http.NoBody)
if err != nil {
return fmt.Errorf("creating DELETE %s request failed: %w", u, err)
}
req.Header.Add("Authorization", tok)
retry:
resp, err := c.client.Do(req)
if err != nil {
return fmt.Errorf("client DELETE %s failed: %w", u, err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return readAdminError(resp.Body)
}
return nil
}
func (c *AdminClient) GetProvisionerPolicy(provisionerName string) (*linkedca.Policy, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")})
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, fmt.Errorf("error generating admin token: %w", err)
}
req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody)
if err != nil {
return nil, fmt.Errorf("creating GET %s request failed: %w", u, err)
}
req.Header.Add("Authorization", tok)
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("client GET %s failed: %w", u, err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readAdminError(resp.Body)
}
var policy = new(linkedca.Policy)
if err := readProtoJSON(resp.Body, policy); err != nil {
return nil, fmt.Errorf("error reading %s: %w", u, err)
}
return policy, nil
}
func (c *AdminClient) CreateProvisionerPolicy(provisionerName string, p *linkedca.Policy) (*linkedca.Policy, error) {
var retried bool
body, err := protojson.Marshal(p)
if err != nil {
return nil, fmt.Errorf("error marshaling request: %w", err)
}
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")})
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, fmt.Errorf("error generating admin token: %w", err)
}
req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("creating POST %s request failed: %w", u, err)
}
req.Header.Add("Authorization", tok)
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("client POST %s failed: %w", u, err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readAdminError(resp.Body)
}
var policy = new(linkedca.Policy)
if err := readProtoJSON(resp.Body, policy); err != nil {
return nil, fmt.Errorf("error reading %s: %w", u, err)
}
return policy, nil
}
func (c *AdminClient) UpdateProvisionerPolicy(provisionerName string, p *linkedca.Policy) (*linkedca.Policy, error) {
var retried bool
body, err := protojson.Marshal(p)
if err != nil {
return nil, fmt.Errorf("error marshaling request: %w", err)
}
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")})
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, fmt.Errorf("error generating admin token: %w", err)
}
req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("creating PUT %s request failed: %w", u, err)
}
req.Header.Add("Authorization", tok)
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("client PUT %s failed: %w", u, err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readAdminError(resp.Body)
}
var policy = new(linkedca.Policy)
if err := readProtoJSON(resp.Body, policy); err != nil {
return nil, fmt.Errorf("error reading %s: %w", u, err)
}
return policy, nil
}
func (c *AdminClient) RemoveProvisionerPolicy(provisionerName string) error {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")})
tok, err := c.generateAdminToken(u)
if err != nil {
return fmt.Errorf("error generating admin token: %w", err)
}
req, err := http.NewRequest(http.MethodDelete, u.String(), http.NoBody)
if err != nil {
return fmt.Errorf("creating DELETE %s request failed: %w", u, err)
}
req.Header.Add("Authorization", tok)
retry:
resp, err := c.client.Do(req)
if err != nil {
return fmt.Errorf("client DELETE %s failed: %w", u, err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return readAdminError(resp.Body)
}
return nil
}
func (c *AdminClient) GetACMEPolicy(provisionerName, reference, keyID string) (*linkedca.Policy, error) {
var retried bool
var urlPath string
switch {
case keyID != "":
urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID)
default:
urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference)
}
u := c.endpoint.ResolveReference(&url.URL{Path: urlPath})
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, fmt.Errorf("error generating admin token: %w", err)
}
req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody)
if err != nil {
return nil, fmt.Errorf("creating GET %s request failed: %w", u, err)
}
req.Header.Add("Authorization", tok)
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("client GET %s failed: %w", u, err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readAdminError(resp.Body)
}
var policy = new(linkedca.Policy)
if err := readProtoJSON(resp.Body, policy); err != nil {
return nil, fmt.Errorf("error reading %s: %w", u, err)
}
return policy, nil
}
func (c *AdminClient) CreateACMEPolicy(provisionerName, reference, keyID string, p *linkedca.Policy) (*linkedca.Policy, error) {
var retried bool
body, err := protojson.Marshal(p)
if err != nil {
return nil, fmt.Errorf("error marshaling request: %w", err)
}
var urlPath string
switch {
case keyID != "":
urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID)
default:
urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference)
}
u := c.endpoint.ResolveReference(&url.URL{Path: urlPath})
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, fmt.Errorf("error generating admin token: %w", err)
}
req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("creating POST %s request failed: %w", u, err)
}
req.Header.Add("Authorization", tok)
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("client POST %s failed: %w", u, err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readAdminError(resp.Body)
}
var policy = new(linkedca.Policy)
if err := readProtoJSON(resp.Body, policy); err != nil {
return nil, fmt.Errorf("error reading %s: %w", u, err)
}
return policy, nil
}
func (c *AdminClient) UpdateACMEPolicy(provisionerName, reference, keyID string, p *linkedca.Policy) (*linkedca.Policy, error) {
var retried bool
body, err := protojson.Marshal(p)
if err != nil {
return nil, fmt.Errorf("error marshaling request: %w", err)
}
var urlPath string
switch {
case keyID != "":
urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID)
default:
urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference)
}
u := c.endpoint.ResolveReference(&url.URL{Path: urlPath})
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, fmt.Errorf("error generating admin token: %w", err)
}
req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("creating PUT %s request failed: %w", u, err)
}
req.Header.Add("Authorization", tok)
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("client PUT %s failed: %w", u, err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readAdminError(resp.Body)
}
var policy = new(linkedca.Policy)
if err := readProtoJSON(resp.Body, policy); err != nil {
return nil, fmt.Errorf("error reading %s: %w", u, err)
}
return policy, nil
}
func (c *AdminClient) RemoveACMEPolicy(provisionerName, reference, keyID string) error {
var retried bool
var urlPath string
switch {
case keyID != "":
urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "key", keyID)
default:
urlPath = path.Join(adminURLPrefix, "acme", "policy", provisionerName, "reference", reference)
}
u := c.endpoint.ResolveReference(&url.URL{Path: urlPath})
tok, err := c.generateAdminToken(u)
if err != nil {
return fmt.Errorf("error generating admin token: %w", err)
}
req, err := http.NewRequest(http.MethodDelete, u.String(), http.NoBody)
if err != nil {
return fmt.Errorf("creating DELETE %s request failed: %w", u, err)
}
req.Header.Add("Authorization", tok)
retry:
resp, err := c.client.Do(req)
if err != nil {
return fmt.Errorf("client DELETE %s failed: %w", u, err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return readAdminError(resp.Body)
}
return nil
}
func readAdminError(r io.ReadCloser) error {
// TODO: not all errors can be read (i.e. 404); seems to be a bigger issue
defer r.Close()

@ -213,8 +213,9 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
adminDB := auth.GetAdminDatabase()
if adminDB != nil {
acmeAdminResponder := adminAPI.NewACMEAdminResponder()
policyAdminResponder := adminAPI.NewPolicyAdminResponder(auth, adminDB, acmeDB)
mux.Route("/admin", func(r chi.Router) {
adminAPI.Route(r, acmeAdminResponder)
adminAPI.Route(r, acmeAdminResponder, policyAdminResponder)
})
}
}

@ -3,15 +3,15 @@ FROM golang:alpine AS builder
WORKDIR /src
COPY . .
RUN apk add --no-cache \
curl \
git \
make && \
make V=1 bin/step-ca
RUN apk add --no-cache curl git make
RUN make V=1 bin/step-ca bin/step-awskms-init bin/step-cloudkms-init
FROM smallstep/step-cli:latest
COPY --from=builder /src/bin/step-ca /usr/local/bin/step-ca
COPY --from=builder /src/bin/step-awskms-init /usr/local/bin/step-awskms-init
COPY --from=builder /src/bin/step-cloudkms-init /usr/local/bin/step-cloudkms-init
USER root
RUN apk add --no-cache libcap && setcap CAP_NET_BIND_SERVICE=+eip /usr/local/bin/step-ca

@ -0,0 +1,34 @@
FROM golang:alpine AS builder
WORKDIR /src
COPY . .
RUN apk add --no-cache curl git make
RUN apk add --no-cache gcc musl-dev pkgconf pcsc-lite-dev
RUN make V=1 GOFLAGS="" build
FROM smallstep/step-cli:latest
COPY --from=builder /src/bin/step-ca /usr/local/bin/step-ca
COPY --from=builder /src/bin/step-awskms-init /usr/local/bin/step-awskms-init
COPY --from=builder /src/bin/step-cloudkms-init /usr/local/bin/step-cloudkms-init
COPY --from=builder /src/bin/step-pkcs11-init /usr/local/bin/step-pkcs11-init
COPY --from=builder /src/bin/step-yubikey-init /usr/local/bin/step-yubikey-init
USER root
RUN apk add --no-cache libcap && setcap CAP_NET_BIND_SERVICE=+eip /usr/local/bin/step-ca
RUN apk add --no-cache pcsc-lite pcsc-lite-libs
USER step
ENV CONFIGPATH="/home/step/config/ca.json"
ENV PWDPATH="/home/step/secrets/password"
VOLUME ["/home/step"]
STOPSIGNAL SIGTERM
HEALTHCHECK CMD step ca health 2>/dev/null | grep "^ok" >/dev/null
COPY docker/entrypoint.sh /entrypoint.sh
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
CMD exec /usr/local/bin/step-ca --password-file $PWDPATH $CONFIGPATH

@ -53,6 +53,10 @@ function step_ca_init () {
mv $STEPPATH/password $PWDPATH
}
if [ -f /usr/sbin/pcscd ]; then
/usr/sbin/pcscd
fi
if [ ! -f "${STEPPATH}/config/ca.json" ]; then
init_if_possible
fi

@ -14,20 +14,26 @@ require (
github.com/Azure/go-autorest/autorest/validation v0.3.1 // indirect
github.com/Masterminds/sprig/v3 v3.2.2
github.com/ThalesIgnite/crypto11 v1.2.4
github.com/aws/aws-sdk-go v1.30.29
github.com/aws/aws-sdk-go v1.37.0
github.com/dgraph-io/ristretto v0.0.4-0.20200906165740-41ebdbffecfd // indirect
github.com/fatih/color v1.9.0 // indirect
github.com/form3tech-oss/jwt-go v3.2.3+incompatible // indirect
github.com/go-chi/chi v4.0.2+incompatible
github.com/go-kit/kit v0.10.0 // indirect
github.com/go-piv/piv-go v1.7.0
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.5.7
github.com/google/uuid v1.3.0
github.com/googleapis/gax-go/v2 v2.1.1
github.com/hashicorp/vault/api v1.3.1
github.com/hashicorp/vault/api/auth/approle v0.1.1
github.com/jhump/protoreflect v1.9.0 // indirect
github.com/mattn/go-colorable v0.1.8 // indirect
github.com/mattn/go-isatty v0.0.13 // indirect
github.com/micromdm/scep/v2 v2.1.0
github.com/miekg/pkcs11 v1.0.3 // indirect
github.com/newrelic/go-agent v2.15.0+incompatible
github.com/pkg/errors v0.9.1
github.com/rs/xid v1.2.1
@ -37,17 +43,21 @@ require (
github.com/smallstep/nosql v0.4.0
github.com/stretchr/testify v1.7.1
github.com/urfave/cli v1.22.4
go.etcd.io/bbolt v1.3.6 // indirect
go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352
go.step.sm/cli-utils v0.7.0
go.step.sm/crypto v0.16.1
go.step.sm/linkedca v0.15.0
go.step.sm/linkedca v0.16.1
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
golang.org/x/net v0.0.0-20220403103023-749bd193bc2b
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 // indirect
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
google.golang.org/api v0.70.0
google.golang.org/genproto v0.0.0-20220222213610-43724f9ea8cf
google.golang.org/grpc v1.44.0
google.golang.org/protobuf v1.27.1
google.golang.org/genproto v0.0.0-20220401170504-314d38edb7de
google.golang.org/grpc v1.45.0
google.golang.org/protobuf v1.28.0
gopkg.in/square/go-jose.v2 v2.6.0
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)
// replace github.com/smallstep/nosql => ../nosql

@ -143,8 +143,8 @@ github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6l
github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A=
github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU=
github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/aws/aws-sdk-go v1.30.29 h1:NXNqBS9hjOCpDL8SyCyl38gZX3LLLunKOJc5E7vJ8P0=
github.com/aws/aws-sdk-go v1.30.29/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0=
github.com/aws/aws-sdk-go v1.37.0 h1:GzFnhOIsrGyQ69s7VgqtrG2BG8v7X7vwB3Xpbd/DBBk=
github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
@ -235,13 +235,15 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.m
github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/evanphx/json-patch/v5 v5.5.0/go.mod h1:G79N1coSVB93tBe7j6PhzjmR3/2VvlbKOFpnXhI9Bw4=
github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys=
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s=
github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU=
github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo=
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
github.com/form3tech-oss/jwt-go v3.2.3+incompatible h1:7ZaBxOI7TMoYBfyA3cQHErNNyAWIKUMIwqxEtgHOs5c=
github.com/form3tech-oss/jwt-go v3.2.3+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4=
github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20=
github.com/frankban/quicktest v1.10.0/go.mod h1:ui7WezCLWMWxVWr1GETZY3smRy0G4KWq9vcPtJmFl7Y=
@ -269,8 +271,9 @@ github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG
github.com/go-piv/piv-go v1.7.0 h1:rfjdFdASfGV5KLJhSjgpGJ5lzVZVtRWn8ovy/H9HQ/U=
github.com/go-piv/piv-go v1.7.0/go.mod h1:ON2WvQncm7dIkCQ7kYJs+nc3V4jHGfrrJnSF8HKy7Gk=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-stack/stack v1.6.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
@ -287,8 +290,9 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU
github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
@ -369,6 +373,7 @@ github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pf
github.com/googleapis/gax-go/v2 v2.1.1 h1:dp3bWCh+PPO1zjRRiCSczJav13sBvG4UhNyVTa1KqdU=
github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gordonklaus/ineffassign v0.0.0-20200309095847-7953dde2c7bf/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU=
github.com/gorilla/context v0.0.0-20160226214623-1ea25387ff6f/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/mux v1.4.0/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
@ -511,11 +516,14 @@ github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0f
github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jhump/protoreflect v1.6.0 h1:h5jfMVslIg6l29nsMs0D8Wj17RDVdNYti0vDN/PZZoE=
github.com/jhump/protoreflect v1.6.0/go.mod h1:eaTn3RZAmMBcV0fifFvlm6VHNz3wSkYyXYWUh7ymB74=
github.com/jhump/protoreflect v1.9.0 h1:npqHz788dryJiR/l6K/RUQAyh2SwV91+d1dnh4RjO9w=
github.com/jhump/protoreflect v1.9.0/go.mod h1:7GcYQDdMU/O/BBrl/cX6PNHpXh6cenjd8pneu5yW7Tg=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc=
github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
@ -585,8 +593,9 @@ github.com/micromdm/scep/v2 v2.1.0 h1:2fS9Rla7qRR266hvUoEauBJ7J6FhgssEiq2OkSKXma
github.com/micromdm/scep/v2 v2.1.0/go.mod h1:BkF7TkPPhmgJAMtHfP+sFTKXmgzNJgLQlvvGoOExBcc=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4=
github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f h1:eVB9ELsoq5ouItQBr5Tj334bhPJG/MX+m7rTchmzVUQ=
github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/miekg/pkcs11 v1.0.3 h1:iMwmD7I5225wv84WxIG/bmxz9AXjWvTWIbM/TYHvWtw=
github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
@ -624,6 +633,7 @@ github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OS
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f/go.mod h1:nwPd6pDNId/Xi16qtKrFHrauSwMNuvk+zcjk89wrnlA=
github.com/newrelic/go-agent v2.15.0+incompatible h1:IB0Fy+dClpBq9aEoIrLyQXzU34JyI1xVTanPLB/+jvU=
github.com/newrelic/go-agent v2.15.0+incompatible/go.mod h1:a8Fv1b/fYhFSReoTU6HDkTYIMZeSVNffmoS726Y0LzQ=
github.com/nishanths/predeclared v0.0.0-20200524104333-86fad755b4d3/go.mod h1:nt3d53pc1VYcphSCIaYAJtnPYnr3Zyn8fMq2wvPGPso=
github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs=
github.com/oklog/run v1.0.0 h1:Ru7dDtJNOyC66gQ5dQmaCa0qIsAUFY3sFpK1Xk8igrw=
github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA=
@ -782,8 +792,9 @@ github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1
github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
go.etcd.io/bbolt v1.3.6 h1:/ecaJf0sk1l4l6V4awd65v2C3ILy7MSj+s/x1ADCIMU=
go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4=
go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg=
go.mozilla.org/pkcs7 v0.0.0-20210730143726-725912489c62/go.mod h1:SNgMg+EgDFwmvSmLRTNKC5fegJjB7v23qTQ0XLGUNHk=
go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 h1:CCriYyAfq1Br1aIYettdHZTy8mBTIPo7We18TuO/bak=
@ -804,8 +815,10 @@ go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/
go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0=
go.step.sm/crypto v0.16.1 h1:4mnZk21cSxyMGxsEpJwZKKvJvDu1PN09UVrWWFNUBdk=
go.step.sm/crypto v0.16.1/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g=
go.step.sm/linkedca v0.15.0 h1:lEkGRDY+u7FudGKt8yEo7nBy5OzceO9s3rl+/sZVL5M=
go.step.sm/linkedca v0.15.0/go.mod h1:W59ucS4vFpuR0g4PtkGbbtXAwxbDEnNCg+ovkej1ANM=
go.step.sm/linkedca v0.16.0 h1:9xdE150lRTEoBq1gJl+prePpSmNqXRXsez3qzRs3Lic=
go.step.sm/linkedca v0.16.0/go.mod h1:W59ucS4vFpuR0g4PtkGbbtXAwxbDEnNCg+ovkej1ANM=
go.step.sm/linkedca v0.16.1 h1:CdbMV5SjnlRsgeYTXaaZmQCkYIgJq8BOzpewri57M2k=
go.step.sm/linkedca v0.16.1/go.mod h1:W59ucS4vFpuR0g4PtkGbbtXAwxbDEnNCg+ovkej1ANM=
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
@ -927,8 +940,9 @@ golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20211020060615-d418f374d309/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd h1:O7DYs+zxREGLKzKoMQrtrEacpb0ZVXA5rIwylE2Xchk=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220403103023-749bd193bc2b h1:vI32FkLJNAWtGD4BwkThwEy6XS7ZLLMHkSkYfF8M0W0=
golang.org/x/net v0.0.0-20220403103023-749bd193bc2b/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -1007,6 +1021,7 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201018230417-eeed37f84f13/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -1041,8 +1056,9 @@ golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220209214540-3681064d5158 h1:rm+CHSpPEEW2IsXUib1ThaHIjuBVZjxNgSKmBLFfD4c=
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64 h1:D1v9ucDTYBtbz5vNuBbAhIMAGhQhJ6Ym5ah3maMVNX4=
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
@ -1062,8 +1078,9 @@ golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxb
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 h1:NusfzzA6yGQ+ua51ck7E3omNUX/JuqbFSaRGqU8CcLI=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba h1:O8mE0/t419eoIwhTFpKVkHiTs/Igowgfkj25AcZrtiE=
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@ -1108,8 +1125,10 @@ golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWc
golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200717024301-6ddee64345a6/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
@ -1244,8 +1263,9 @@ google.golang.org/genproto v0.0.0-20211221195035-429b39de9b1c/go.mod h1:5CzLGKJ6
google.golang.org/genproto v0.0.0-20220126215142-9970aeb2e350/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20220207164111-0872dc986b00/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20220218161850-94dd64e39d7c/go.mod h1:kGP+zUP2Ddo0ayMi4YuN7C3WZyJvGLZRh8Z5wnAqvEI=
google.golang.org/genproto v0.0.0-20220222213610-43724f9ea8cf h1:SVYXkUz2yZS9FWb2Gm8ivSlbNQzL2Z/NpPKE3RG2jWk=
google.golang.org/genproto v0.0.0-20220222213610-43724f9ea8cf/go.mod h1:kGP+zUP2Ddo0ayMi4YuN7C3WZyJvGLZRh8Z5wnAqvEI=
google.golang.org/genproto v0.0.0-20220401170504-314d38edb7de h1:9Ti5SG2U4cAcluryUo/sFay3TQKoxiFMfaT0pbizU7k=
google.golang.org/genproto v0.0.0-20220401170504-314d38edb7de/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo=
google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
@ -1279,8 +1299,9 @@ google.golang.org/grpc v1.39.1/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnD
google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34=
google.golang.org/grpc v1.40.1/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34=
google.golang.org/grpc v1.41.0/go.mod h1:U3l9uK9J0sini8mHphKoXyaqDA/8VyGnDee1zzIUK6k=
google.golang.org/grpc v1.44.0 h1:weqSxi/TMs1SqFRMHCtBgXRs8k3X39QIDEZ0pRcttUg=
google.golang.org/grpc v1.44.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU=
google.golang.org/grpc v1.45.0 h1:NEpgUqV3Z+ZjkqMsxMg11IaDrXY4RY6CQukSGK0uI1M=
google.golang.org/grpc v1.45.0/go.mod h1:lN7owxKUQEqMfSyQikvvk5tf/6zMPsrK+ONuO11+0rQ=
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
@ -1292,10 +1313,12 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4=
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.25.1-0.20200805231151-a709e31e5d12/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@ -1320,11 +1343,13 @@ gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

@ -0,0 +1,301 @@
package policy
import (
"crypto/x509"
"fmt"
"net"
"net/url"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/x509util"
)
type NamePolicyReason int
const (
// NotAllowed results when an instance of NamePolicyEngine
// determines that there's a constraint which doesn't permit
// a DNS or another type of SAN to be signed (or otherwise used).
NotAllowed NamePolicyReason = iota + 1
// CannotParseDomain is returned when an error occurs
// when parsing the domain part of SAN or subject.
CannotParseDomain
// CannotParseRFC822Name is returned when an error
// occurs when parsing an email address.
CannotParseRFC822Name
// CannotMatch is the type of error returned when
// an error happens when matching SAN types.
CannotMatchNameToConstraint
)
type NameType string
const (
CNNameType NameType = "cn"
DNSNameType NameType = "dns"
IPNameType NameType = "ip"
EmailNameType NameType = "email"
URINameType NameType = "uri"
PrincipalNameType NameType = "principal"
)
type NamePolicyError struct {
Reason NamePolicyReason
NameType NameType
Name string
detail string
}
func (e *NamePolicyError) Error() string {
switch e.Reason {
case NotAllowed:
return fmt.Sprintf("%s name %q not allowed", e.NameType, e.Name)
case CannotParseDomain:
return fmt.Sprintf("cannot parse %s domain %q", e.NameType, e.Name)
case CannotParseRFC822Name:
return fmt.Sprintf("cannot parse %s rfc822Name %q", e.NameType, e.Name)
case CannotMatchNameToConstraint:
return fmt.Sprintf("error matching %s name %q to constraint", e.NameType, e.Name)
default:
return fmt.Sprintf("unknown error reason (%d): %s", e.Reason, e.detail)
}
}
func (e *NamePolicyError) Detail() string {
return e.detail
}
// NamePolicyEngine can be used to check that a CSR or Certificate meets all allowed and
// denied names before a CA creates and/or signs the Certificate.
// TODO(hs): the X509 RFC also defines name checks on directory name; support that?
// TODO(hs): implement Stringer interface: describe the contents of the NamePolicyEngine?
// TODO(hs): implement matching URI schemes, paths, etc; not just the domain part of URI domains
type NamePolicyEngine struct {
// verifySubjectCommonName is set when Subject Common Name must be verified
verifySubjectCommonName bool
// allowLiteralWildcardNames allows literal wildcard DNS domains
allowLiteralWildcardNames bool
// permitted and exluded constraints similar to x509 Name Constraints
permittedCommonNames []string
excludedCommonNames []string
permittedDNSDomains []string
excludedDNSDomains []string
permittedIPRanges []*net.IPNet
excludedIPRanges []*net.IPNet
permittedEmailAddresses []string
excludedEmailAddresses []string
permittedURIDomains []string
excludedURIDomains []string
permittedPrincipals []string
excludedPrincipals []string
// some internal counts for housekeeping
numberOfCommonNameConstraints int
numberOfDNSDomainConstraints int
numberOfIPRangeConstraints int
numberOfEmailAddressConstraints int
numberOfURIDomainConstraints int
numberOfPrincipalConstraints int
totalNumberOfPermittedConstraints int
totalNumberOfExcludedConstraints int
totalNumberOfConstraints int
}
// NewNamePolicyEngine creates a new NamePolicyEngine with NamePolicyOptions
func New(opts ...NamePolicyOption) (*NamePolicyEngine, error) {
e := &NamePolicyEngine{}
for _, option := range opts {
if err := option(e); err != nil {
return nil, err
}
}
e.permittedCommonNames = removeDuplicates(e.permittedCommonNames)
e.permittedDNSDomains = removeDuplicates(e.permittedDNSDomains)
e.permittedIPRanges = removeDuplicateIPNets(e.permittedIPRanges)
e.permittedEmailAddresses = removeDuplicates(e.permittedEmailAddresses)
e.permittedURIDomains = removeDuplicates(e.permittedURIDomains)
e.permittedPrincipals = removeDuplicates(e.permittedPrincipals)
e.excludedCommonNames = removeDuplicates(e.excludedCommonNames)
e.excludedDNSDomains = removeDuplicates(e.excludedDNSDomains)
e.excludedIPRanges = removeDuplicateIPNets(e.excludedIPRanges)
e.excludedEmailAddresses = removeDuplicates(e.excludedEmailAddresses)
e.excludedURIDomains = removeDuplicates(e.excludedURIDomains)
e.excludedPrincipals = removeDuplicates(e.excludedPrincipals)
e.numberOfCommonNameConstraints = len(e.permittedCommonNames) + len(e.excludedCommonNames)
e.numberOfDNSDomainConstraints = len(e.permittedDNSDomains) + len(e.excludedDNSDomains)
e.numberOfIPRangeConstraints = len(e.permittedIPRanges) + len(e.excludedIPRanges)
e.numberOfEmailAddressConstraints = len(e.permittedEmailAddresses) + len(e.excludedEmailAddresses)
e.numberOfURIDomainConstraints = len(e.permittedURIDomains) + len(e.excludedURIDomains)
e.numberOfPrincipalConstraints = len(e.permittedPrincipals) + len(e.excludedPrincipals)
e.totalNumberOfPermittedConstraints = len(e.permittedCommonNames) + len(e.permittedDNSDomains) +
len(e.permittedIPRanges) + len(e.permittedEmailAddresses) + len(e.permittedURIDomains) +
len(e.permittedPrincipals)
e.totalNumberOfExcludedConstraints = len(e.excludedCommonNames) + len(e.excludedDNSDomains) +
len(e.excludedIPRanges) + len(e.excludedEmailAddresses) + len(e.excludedURIDomains) +
len(e.excludedPrincipals)
e.totalNumberOfConstraints = e.totalNumberOfPermittedConstraints + e.totalNumberOfExcludedConstraints
return e, nil
}
// removeDuplicates returns a new slice of strings with
// duplicate values removed. It retains the order of elements
// in the source slice.
func removeDuplicates(items []string) (ret []string) {
// no need to remove dupes; return original
if len(items) <= 1 {
return items
}
keys := make(map[string]struct{}, len(items))
ret = make([]string, 0, len(items))
for _, item := range items {
if _, ok := keys[item]; ok {
continue
}
keys[item] = struct{}{}
ret = append(ret, item)
}
return
}
// removeDuplicateIPNets returns a new slice of net.IPNets with
// duplicate values removed. It retains the order of elements in
// the source slice. An IPNet is considered duplicate if its CIDR
// notation exists multiple times in the slice.
func removeDuplicateIPNets(items []*net.IPNet) (ret []*net.IPNet) {
// no need to remove dupes; return original
if len(items) <= 1 {
return items
}
keys := make(map[string]struct{}, len(items))
ret = make([]*net.IPNet, 0, len(items))
for _, item := range items {
key := item.String() // use CIDR notation as key
if _, ok := keys[key]; ok {
continue
}
keys[key] = struct{}{}
ret = append(ret, item)
}
// TODO(hs): implement filter of fully overlapping ranges,
// so that the smaller ones are automatically removed?
return
}
// IsX509CertificateAllowed verifies that all SANs in a Certificate are allowed.
func (e *NamePolicyEngine) IsX509CertificateAllowed(cert *x509.Certificate) error {
if err := e.validateNames(cert.DNSNames, cert.IPAddresses, cert.EmailAddresses, cert.URIs, []string{}); err != nil {
return err
}
if e.verifySubjectCommonName {
return e.validateCommonName(cert.Subject.CommonName)
}
return nil
}
// IsX509CertificateRequestAllowed verifies that all names in the CSR are allowed.
func (e *NamePolicyEngine) IsX509CertificateRequestAllowed(csr *x509.CertificateRequest) error {
if err := e.validateNames(csr.DNSNames, csr.IPAddresses, csr.EmailAddresses, csr.URIs, []string{}); err != nil {
return err
}
if e.verifySubjectCommonName {
return e.validateCommonName(csr.Subject.CommonName)
}
return nil
}
// AreSANSAllowed verifies that all names in the slice of SANs are allowed.
// The SANs are first split into DNS names, IPs, email addresses and URIs.
func (e *NamePolicyEngine) AreSANsAllowed(sans []string) error {
dnsNames, ips, emails, uris := x509util.SplitSANs(sans)
if err := e.validateNames(dnsNames, ips, emails, uris, []string{}); err != nil {
return err
}
return nil
}
// IsDNSAllowed verifies a single DNS domain is allowed.
func (e *NamePolicyEngine) IsDNSAllowed(dns string) error {
if err := e.validateNames([]string{dns}, []net.IP{}, []string{}, []*url.URL{}, []string{}); err != nil {
return err
}
return nil
}
// IsIPAllowed verifies a single IP domain is allowed.
func (e *NamePolicyEngine) IsIPAllowed(ip net.IP) error {
if err := e.validateNames([]string{}, []net.IP{ip}, []string{}, []*url.URL{}, []string{}); err != nil {
return err
}
return nil
}
// IsSSHCertificateAllowed verifies that all principals in an SSH certificate are allowed.
func (e *NamePolicyEngine) IsSSHCertificateAllowed(cert *ssh.Certificate) error {
dnsNames, ips, emails, principals, err := splitSSHPrincipals(cert)
if err != nil {
return err
}
if err := e.validateNames(dnsNames, ips, emails, []*url.URL{}, principals); err != nil {
return err
}
return nil
}
// splitPrincipals splits SSH certificate principals into DNS names, emails and usernames.
func splitSSHPrincipals(cert *ssh.Certificate) (dnsNames []string, ips []net.IP, emails, principals []string, err error) {
dnsNames = []string{}
ips = []net.IP{}
emails = []string{}
principals = []string{}
var uris []*url.URL
switch cert.CertType {
case ssh.HostCert:
dnsNames, ips, emails, uris = x509util.SplitSANs(cert.ValidPrincipals)
if len(uris) > 0 {
err = fmt.Errorf("URL principals %v not expected in SSH host certificate ", uris)
}
case ssh.UserCert:
// re-using SplitSANs results in anything that can't be parsed as an IP, URI or email
// to be considered a username principal. This allows usernames like h.slatman to be present
// in the SSH certificate. We're exluding URIs, because they can be confusing
// when used in a SSH user certificate.
principals, ips, emails, uris = x509util.SplitSANs(cert.ValidPrincipals)
if len(ips) > 0 {
err = fmt.Errorf("IP principals %v not expected in SSH user certificate ", ips)
}
if len(uris) > 0 {
err = fmt.Errorf("URL principals %v not expected in SSH user certificate ", uris)
}
default:
err = fmt.Errorf("unexpected SSH certificate type %d", cert.CertType)
}
return
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,386 @@
package policy
import (
"fmt"
"net"
"strings"
"golang.org/x/net/idna"
)
type NamePolicyOption func(e *NamePolicyEngine) error
// TODO: wrap (more) errors; and prove a set of known (exported) errors
func WithSubjectCommonNameVerification() NamePolicyOption {
return func(e *NamePolicyEngine) error {
e.verifySubjectCommonName = true
return nil
}
}
func WithAllowLiteralWildcardNames() NamePolicyOption {
return func(e *NamePolicyEngine) error {
e.allowLiteralWildcardNames = true
return nil
}
}
func WithPermittedCommonNames(commonNames ...string) NamePolicyOption {
return func(g *NamePolicyEngine) error {
normalizedCommonNames := make([]string, len(commonNames))
for i, commonName := range commonNames {
normalizedCommonName, err := normalizeAndValidateCommonName(commonName)
if err != nil {
return fmt.Errorf("cannot parse permitted common name constraint %q: %w", commonName, err)
}
normalizedCommonNames[i] = normalizedCommonName
}
g.permittedCommonNames = normalizedCommonNames
return nil
}
}
func WithExcludedCommonNames(commonNames ...string) NamePolicyOption {
return func(g *NamePolicyEngine) error {
normalizedCommonNames := make([]string, len(commonNames))
for i, commonName := range commonNames {
normalizedCommonName, err := normalizeAndValidateCommonName(commonName)
if err != nil {
return fmt.Errorf("cannot parse excluded common name constraint %q: %w", commonName, err)
}
normalizedCommonNames[i] = normalizedCommonName
}
g.excludedCommonNames = normalizedCommonNames
return nil
}
}
func WithPermittedDNSDomains(domains ...string) NamePolicyOption {
return func(e *NamePolicyEngine) error {
normalizedDomains := make([]string, len(domains))
for i, domain := range domains {
normalizedDomain, err := normalizeAndValidateDNSDomainConstraint(domain)
if err != nil {
return fmt.Errorf("cannot parse permitted domain constraint %q: %w", domain, err)
}
normalizedDomains[i] = normalizedDomain
}
e.permittedDNSDomains = normalizedDomains
return nil
}
}
func WithExcludedDNSDomains(domains ...string) NamePolicyOption {
return func(e *NamePolicyEngine) error {
normalizedDomains := make([]string, len(domains))
for i, domain := range domains {
normalizedDomain, err := normalizeAndValidateDNSDomainConstraint(domain)
if err != nil {
return fmt.Errorf("cannot parse excluded domain constraint %q: %w", domain, err)
}
normalizedDomains[i] = normalizedDomain
}
e.excludedDNSDomains = normalizedDomains
return nil
}
}
func WithPermittedIPRanges(ipRanges ...*net.IPNet) NamePolicyOption {
return func(e *NamePolicyEngine) error {
e.permittedIPRanges = ipRanges
return nil
}
}
func WithPermittedCIDRs(cidrs ...string) NamePolicyOption {
return func(e *NamePolicyEngine) error {
networks := make([]*net.IPNet, len(cidrs))
for i, cidr := range cidrs {
_, nw, err := net.ParseCIDR(cidr)
if err != nil {
return fmt.Errorf("cannot parse permitted CIDR constraint %q", cidr)
}
networks[i] = nw
}
e.permittedIPRanges = networks
return nil
}
}
func WithExcludedCIDRs(cidrs ...string) NamePolicyOption {
return func(e *NamePolicyEngine) error {
networks := make([]*net.IPNet, len(cidrs))
for i, cidr := range cidrs {
_, nw, err := net.ParseCIDR(cidr)
if err != nil {
return fmt.Errorf("cannot parse excluded CIDR constraint %q", cidr)
}
networks[i] = nw
}
e.excludedIPRanges = networks
return nil
}
}
func WithPermittedIPsOrCIDRs(ipsOrCIDRs ...string) NamePolicyOption {
return func(e *NamePolicyEngine) error {
networks := make([]*net.IPNet, len(ipsOrCIDRs))
for i, ipOrCIDR := range ipsOrCIDRs {
_, nw, err := net.ParseCIDR(ipOrCIDR)
if err == nil {
networks[i] = nw
} else if ip := net.ParseIP(ipOrCIDR); ip != nil {
networks[i] = networkFor(ip)
} else {
return fmt.Errorf("cannot parse permitted constraint %q as IP nor CIDR", ipOrCIDR)
}
}
e.permittedIPRanges = networks
return nil
}
}
func WithExcludedIPsOrCIDRs(ipsOrCIDRs ...string) NamePolicyOption {
return func(e *NamePolicyEngine) error {
networks := make([]*net.IPNet, len(ipsOrCIDRs))
for i, ipOrCIDR := range ipsOrCIDRs {
_, nw, err := net.ParseCIDR(ipOrCIDR)
if err == nil {
networks[i] = nw
} else if ip := net.ParseIP(ipOrCIDR); ip != nil {
networks[i] = networkFor(ip)
} else {
return fmt.Errorf("cannot parse excluded constraint %q as IP nor CIDR", ipOrCIDR)
}
}
e.excludedIPRanges = networks
return nil
}
}
func WithExcludedIPRanges(ipRanges ...*net.IPNet) NamePolicyOption {
return func(e *NamePolicyEngine) error {
e.excludedIPRanges = ipRanges
return nil
}
}
func WithPermittedEmailAddresses(emailAddresses ...string) NamePolicyOption {
return func(e *NamePolicyEngine) error {
normalizedEmailAddresses := make([]string, len(emailAddresses))
for i, email := range emailAddresses {
normalizedEmailAddress, err := normalizeAndValidateEmailConstraint(email)
if err != nil {
return fmt.Errorf("cannot parse permitted email constraint %q: %w", email, err)
}
normalizedEmailAddresses[i] = normalizedEmailAddress
}
e.permittedEmailAddresses = normalizedEmailAddresses
return nil
}
}
func WithExcludedEmailAddresses(emailAddresses ...string) NamePolicyOption {
return func(e *NamePolicyEngine) error {
normalizedEmailAddresses := make([]string, len(emailAddresses))
for i, email := range emailAddresses {
normalizedEmailAddress, err := normalizeAndValidateEmailConstraint(email)
if err != nil {
return fmt.Errorf("cannot parse excluded email constraint %q: %w", email, err)
}
normalizedEmailAddresses[i] = normalizedEmailAddress
}
e.excludedEmailAddresses = normalizedEmailAddresses
return nil
}
}
func WithPermittedURIDomains(uriDomains ...string) NamePolicyOption {
return func(e *NamePolicyEngine) error {
normalizedURIDomains := make([]string, len(uriDomains))
for i, domain := range uriDomains {
normalizedURIDomain, err := normalizeAndValidateURIDomainConstraint(domain)
if err != nil {
return fmt.Errorf("cannot parse permitted URI domain constraint %q: %w", domain, err)
}
normalizedURIDomains[i] = normalizedURIDomain
}
e.permittedURIDomains = normalizedURIDomains
return nil
}
}
func WithExcludedURIDomains(domains ...string) NamePolicyOption {
return func(e *NamePolicyEngine) error {
normalizedURIDomains := make([]string, len(domains))
for i, domain := range domains {
normalizedURIDomain, err := normalizeAndValidateURIDomainConstraint(domain)
if err != nil {
return fmt.Errorf("cannot parse excluded URI domain constraint %q: %w", domain, err)
}
normalizedURIDomains[i] = normalizedURIDomain
}
e.excludedURIDomains = normalizedURIDomains
return nil
}
}
func WithPermittedPrincipals(principals ...string) NamePolicyOption {
return func(g *NamePolicyEngine) error {
g.permittedPrincipals = principals
return nil
}
}
func WithExcludedPrincipals(principals ...string) NamePolicyOption {
return func(g *NamePolicyEngine) error {
g.excludedPrincipals = principals
return nil
}
}
func networkFor(ip net.IP) *net.IPNet {
var mask net.IPMask
if !isIPv4(ip) {
mask = net.CIDRMask(128, 128)
} else {
mask = net.CIDRMask(32, 32)
}
nw := &net.IPNet{
IP: ip,
Mask: mask,
}
return nw
}
func isIPv4(ip net.IP) bool {
return ip.To4() != nil
}
func normalizeAndValidateCommonName(constraint string) (string, error) {
normalizedConstraint := strings.ToLower(strings.TrimSpace(constraint))
if normalizedConstraint == "" {
return "", fmt.Errorf("contraint %q can not be empty or white space string", constraint)
}
if normalizedConstraint == "*" {
return "", fmt.Errorf("wildcard constraint %q is not supported", constraint)
}
return normalizedConstraint, nil
}
func normalizeAndValidateDNSDomainConstraint(constraint string) (string, error) {
normalizedConstraint := strings.ToLower(strings.TrimSpace(constraint))
if normalizedConstraint == "" {
return "", fmt.Errorf("contraint %q can not be empty or white space string", constraint)
}
if strings.Contains(normalizedConstraint, "..") {
return "", fmt.Errorf("domain constraint %q cannot have empty labels", constraint)
}
if strings.HasPrefix(normalizedConstraint, ".") {
return "", fmt.Errorf("domain constraint %q with wildcard should start with *", constraint)
}
if strings.LastIndex(normalizedConstraint, "*") > 0 {
return "", fmt.Errorf("domain constraint %q can only have wildcard as starting character", constraint)
}
if len(normalizedConstraint) >= 2 && normalizedConstraint[0] == '*' && normalizedConstraint[1] != '.' {
return "", fmt.Errorf("wildcard character in domain constraint %q can only be used to match (full) labels", constraint)
}
if strings.HasPrefix(normalizedConstraint, "*.") {
normalizedConstraint = normalizedConstraint[1:] // cut off wildcard character; keep the period
}
normalizedConstraint, err := idna.Lookup.ToASCII(normalizedConstraint)
if err != nil {
return "", fmt.Errorf("domain constraint %q can not be converted to ASCII: %w", constraint, err)
}
if _, ok := domainToReverseLabels(normalizedConstraint); !ok {
return "", fmt.Errorf("cannot parse domain constraint %q", constraint)
}
return normalizedConstraint, nil
}
func normalizeAndValidateEmailConstraint(constraint string) (string, error) {
normalizedConstraint := strings.ToLower(strings.TrimSpace(constraint))
if normalizedConstraint == "" {
return "", fmt.Errorf("email contraint %q can not be empty or white space string", constraint)
}
if strings.Contains(normalizedConstraint, "*") {
return "", fmt.Errorf("email constraint %q cannot contain asterisk wildcard", constraint)
}
if strings.Count(normalizedConstraint, "@") > 1 {
return "", fmt.Errorf("email constraint %q contains too many @ characters", constraint)
}
if normalizedConstraint[0] == '@' {
normalizedConstraint = normalizedConstraint[1:] // remove the leading @ as wildcard for emails
}
if normalizedConstraint[0] == '.' {
return "", fmt.Errorf("email constraint %q cannot start with period", constraint)
}
if strings.Contains(normalizedConstraint, "@") {
mailbox, ok := parseRFC2821Mailbox(normalizedConstraint)
if !ok {
return "", fmt.Errorf("cannot parse email constraint %q as RFC 2821 mailbox", constraint)
}
// According to RFC 5280, section 7.5, emails are considered to match if the local part is
// an exact match and the host (domain) part matches the ASCII representation (case-insensitive):
// https://datatracker.ietf.org/doc/html/rfc5280#section-7.5
domainASCII, err := idna.Lookup.ToASCII(mailbox.domain)
if err != nil {
return "", fmt.Errorf("email constraint %q domain part %q cannot be converted to ASCII: %w", constraint, mailbox.domain, err)
}
normalizedConstraint = mailbox.local + "@" + domainASCII
} else {
var err error
normalizedConstraint, err = idna.Lookup.ToASCII(normalizedConstraint)
if err != nil {
return "", fmt.Errorf("email constraint %q cannot be converted to ASCII: %w", constraint, err)
}
}
if _, ok := domainToReverseLabels(normalizedConstraint); !ok {
return "", fmt.Errorf("cannot parse email domain constraint %q", constraint)
}
return normalizedConstraint, nil
}
func normalizeAndValidateURIDomainConstraint(constraint string) (string, error) {
normalizedConstraint := strings.ToLower(strings.TrimSpace(constraint))
if normalizedConstraint == "" {
return "", fmt.Errorf("URI domain contraint %q cannot be empty or white space string", constraint)
}
if strings.Contains(normalizedConstraint, "://") {
return "", fmt.Errorf("URI domain constraint %q contains scheme (not supported yet)", constraint)
}
if strings.Contains(normalizedConstraint, "..") {
return "", fmt.Errorf("URI domain constraint %q cannot have empty labels", constraint)
}
if strings.HasPrefix(normalizedConstraint, ".") {
return "", fmt.Errorf("URI domain constraint %q with wildcard should start with *", constraint)
}
if strings.LastIndex(normalizedConstraint, "*") > 0 {
return "", fmt.Errorf("URI domain constraint %q can only have wildcard as starting character", constraint)
}
if strings.HasPrefix(normalizedConstraint, "*.") {
normalizedConstraint = normalizedConstraint[1:] // cut off wildcard character; keep the period
}
// we're being strict with square brackets in domains; we don't allow them, no matter what
if strings.Contains(normalizedConstraint, "[") || strings.Contains(normalizedConstraint, "]") {
return "", fmt.Errorf("URI domain constraint %q contains invalid square brackets", constraint)
}
if _, _, err := net.SplitHostPort(normalizedConstraint); err == nil {
// a successful split (likely) with host and port; we don't currently allow ports in the config
return "", fmt.Errorf("URI domain constraint %q cannot contain port", constraint)
}
// check if the host part of the URI domain constraint is an IP
if net.ParseIP(normalizedConstraint) != nil {
return "", fmt.Errorf("URI domain constraint %q cannot be an IP", constraint)
}
normalizedConstraint, err := idna.Lookup.ToASCII(normalizedConstraint)
if err != nil {
return "", fmt.Errorf("URI domain constraint %q cannot be converted to ASCII: %w", constraint, err)
}
_, ok := domainToReverseLabels(normalizedConstraint)
if !ok {
return "", fmt.Errorf("cannot parse URI domain constraint %q", constraint)
}
return normalizedConstraint, nil
}

@ -0,0 +1,125 @@
//go:build !go1.18
// +build !go1.18
package policy
import "testing"
func Test_normalizeAndValidateURIDomainConstraint(t *testing.T) {
tests := []struct {
name string
constraint string
want string
wantErr bool
}{
{
name: "fail/empty-constraint",
constraint: "",
want: "",
wantErr: true,
},
{
name: "fail/scheme-https",
constraint: `https://*.local`,
want: "",
wantErr: true,
},
{
name: "fail/too-many-asterisks",
constraint: "**.local",
want: "",
wantErr: true,
},
{
name: "fail/empty-label",
constraint: "..local",
want: "",
wantErr: true,
},
{
name: "fail/empty-reverse",
constraint: ".",
want: "",
wantErr: true,
},
{
name: "fail/no-asterisk",
constraint: ".example.com",
want: "",
wantErr: true,
},
{
name: "fail/domain-with-port",
constraint: "host.local:8443",
want: "",
wantErr: true,
},
{
name: "fail/ipv4",
constraint: "127.0.0.1",
want: "",
wantErr: true,
},
{
name: "fail/ipv6-brackets",
constraint: "[::1]",
want: "",
wantErr: true,
},
{
name: "fail/ipv6-no-brackets",
constraint: "::1",
want: "",
wantErr: true,
},
{
name: "fail/ipv6-no-brackets",
constraint: "[::1",
want: "",
wantErr: true,
},
{
name: "fail/idna-internationalized-domain-name-lookup",
constraint: `\00local`,
want: "",
wantErr: true,
},
{
name: "ok/wildcard",
constraint: "*.local",
want: ".local",
wantErr: false,
},
{
name: "ok/specific-domain",
constraint: "example.local",
want: "example.local",
wantErr: false,
},
{
name: "ok/idna-internationalized-domain-name-lookup",
constraint: `*.bücher.example.com`,
want: ".xn--bcher-kva.example.com",
wantErr: false,
},
{
// IDNA2003 vs. 2008 deviation: https://unicode.org/reports/tr46/#Deviations results
// in a difference between Go 1.18 and lower versions. Go 1.18 expects ".xn--fa-hia.de"; not .fass.de.
name: "ok/idna-internationalized-domain-name-lookup-deviation",
constraint: `*.faß.de`,
want: ".fass.de",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := normalizeAndValidateURIDomainConstraint(tt.constraint)
if (err != nil) != tt.wantErr {
t.Errorf("normalizeAndValidateURIDomainConstraint() error = %v, wantErr %v", err, tt.wantErr)
}
if got != tt.want {
t.Errorf("normalizeAndValidateURIDomainConstraint() = %v, want %v", got, tt.want)
}
})
}
}

@ -0,0 +1,125 @@
//go:build go1.18
// +build go1.18
package policy
import "testing"
func Test_normalizeAndValidateURIDomainConstraint(t *testing.T) {
tests := []struct {
name string
constraint string
want string
wantErr bool
}{
{
name: "fail/empty-constraint",
constraint: "",
want: "",
wantErr: true,
},
{
name: "fail/scheme-https",
constraint: `https://*.local`,
want: "",
wantErr: true,
},
{
name: "fail/too-many-asterisks",
constraint: "**.local",
want: "",
wantErr: true,
},
{
name: "fail/empty-label",
constraint: "..local",
want: "",
wantErr: true,
},
{
name: "fail/empty-reverse",
constraint: ".",
want: "",
wantErr: true,
},
{
name: "fail/domain-with-port",
constraint: "host.local:8443",
want: "",
wantErr: true,
},
{
name: "fail/no-asterisk",
constraint: ".example.com",
want: "",
wantErr: true,
},
{
name: "fail/ipv4",
constraint: "127.0.0.1",
want: "",
wantErr: true,
},
{
name: "fail/ipv6-brackets",
constraint: "[::1]",
want: "",
wantErr: true,
},
{
name: "fail/ipv6-no-brackets",
constraint: "::1",
want: "",
wantErr: true,
},
{
name: "fail/ipv6-no-brackets",
constraint: "[::1",
want: "",
wantErr: true,
},
{
name: "fail/idna-internationalized-domain-name-lookup",
constraint: `\00local`,
want: "",
wantErr: true,
},
{
name: "ok/wildcard",
constraint: "*.local",
want: ".local",
wantErr: false,
},
{
name: "ok/specific-domain",
constraint: "example.local",
want: "example.local",
wantErr: false,
},
{
name: "ok/idna-internationalized-domain-name-lookup",
constraint: `*.bücher.example.com`,
want: ".xn--bcher-kva.example.com",
wantErr: false,
},
{
// IDNA2003 vs. 2008 deviation: https://unicode.org/reports/tr46/#Deviations results
// in a difference between Go 1.18 and lower versions. Go 1.18 expects ".xn--fa-hia.de"; not .fass.de.
name: "ok/idna-internationalized-domain-name-lookup-deviation",
constraint: `*.faß.de`,
want: ".xn--fa-hia.de",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := normalizeAndValidateURIDomainConstraint(tt.constraint)
if (err != nil) != tt.wantErr {
t.Errorf("normalizeAndValidateURIDomainConstraint() error = %v, wantErr %v", err, tt.wantErr)
}
if got != tt.want {
t.Errorf("normalizeAndValidateURIDomainConstraint() = %v, want %v", got, tt.want)
}
})
}
}

@ -0,0 +1,660 @@
package policy
import (
"net"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
)
func Test_normalizeAndValidateCommonName(t *testing.T) {
tests := []struct {
name string
constraint string
want string
wantErr bool
}{
{
name: "fail/empty-constraint",
constraint: "",
want: "",
wantErr: true,
},
{
name: "fail/wildcard",
constraint: "*",
want: "",
wantErr: true,
},
{
name: "ok",
constraint: "step",
want: "step",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := normalizeAndValidateCommonName(tt.constraint)
if (err != nil) != tt.wantErr {
t.Errorf("normalizeAndValidateCommonName() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("normalizeAndValidateCommonName() = %v, want %v", got, tt.want)
}
})
}
}
func Test_normalizeAndValidateDNSDomainConstraint(t *testing.T) {
tests := []struct {
name string
constraint string
want string
wantErr bool
}{
{
name: "fail/empty-constraint",
constraint: "",
want: "",
wantErr: true,
},
{
name: "fail/wildcard-partial-label",
constraint: "*xxxx.local",
want: "",
wantErr: true,
},
{
name: "fail/wildcard-in-the-middle",
constraint: "x.*.local",
want: "",
wantErr: true,
},
{
name: "fail/empty-label",
constraint: "..local",
want: "",
wantErr: true,
},
{
name: "fail/empty-reverse",
constraint: ".",
want: "",
wantErr: true,
},
{
name: "fail/no-asterisk",
constraint: ".example.com",
want: "",
wantErr: true,
},
{
name: "fail/idna-internationalized-domain-name-lookup",
constraint: `\00.local`, // invalid IDNA ASCII character
want: "",
wantErr: true,
},
{
name: "ok/wildcard",
constraint: "*.local",
want: ".local",
wantErr: false,
},
{
name: "ok/specific-domain",
constraint: "example.local",
want: "example.local",
wantErr: false,
},
{
name: "ok/idna-internationalized-domain-name-punycode",
constraint: "*.xn--fsq.jp", // Example value from https://www.w3.org/International/articles/idn-and-iri/
want: ".xn--fsq.jp",
wantErr: false,
},
{
name: "ok/idna-internationalized-domain-name-lookup-transformed",
constraint: "*.例.jp", // Example value from https://www.w3.org/International/articles/idn-and-iri/
want: ".xn--fsq.jp",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := normalizeAndValidateDNSDomainConstraint(tt.constraint)
if (err != nil) != tt.wantErr {
t.Errorf("normalizeAndValidateDNSDomainConstraint() error = %v, wantErr %v", err, tt.wantErr)
}
if got != tt.want {
t.Errorf("normalizeAndValidateDNSDomainConstraint() = %v, want %v", got, tt.want)
}
})
}
}
func Test_normalizeAndValidateEmailConstraint(t *testing.T) {
tests := []struct {
name string
constraint string
want string
wantErr bool
}{
{
name: "fail/empty-constraint",
constraint: "",
want: "",
wantErr: true,
},
{
name: "fail/asterisk",
constraint: "*.local",
want: "",
wantErr: true,
},
{
name: "fail/period",
constraint: ".local",
want: "",
wantErr: true,
},
{
name: "fail/@period",
constraint: "@.local",
want: "",
wantErr: true,
},
{
name: "fail/too-many-@s",
constraint: "@local@example.com",
want: "",
wantErr: true,
},
{
name: "fail/parse-mailbox",
constraint: "mail@example.com" + string(byte(0)),
want: "",
wantErr: true,
},
{
name: "fail/idna-internationalized-domain",
constraint: "mail@xn--bla.local",
want: "",
wantErr: true,
},
{
name: "fail/idna-internationalized-domain-name-lookup",
constraint: `\00local`,
want: "",
wantErr: true,
},
{
name: "fail/parse-domain",
constraint: "x..example.com",
want: "",
wantErr: true,
},
{
name: "ok/wildcard",
constraint: "@local",
want: "local",
wantErr: false,
},
{
name: "ok/specific-mail",
constraint: "mail@local",
want: "mail@local",
wantErr: false,
},
// TODO(hs): fix the below; doesn't get past parseRFC2821Mailbox; I think it should be allowed.
// {
// name: "ok/idna-internationalized-local",
// constraint: `bücher@local`,
// want: "bücher@local",
// wantErr: false,
// },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := normalizeAndValidateEmailConstraint(tt.constraint)
if (err != nil) != tt.wantErr {
t.Errorf("normalizeAndValidateEmailConstraint() error = %v, wantErr %v", err, tt.wantErr)
}
if got != tt.want {
t.Errorf("normalizeAndValidateEmailConstraint() = %v, want %v", got, tt.want)
}
})
}
}
func TestNew(t *testing.T) {
type test struct {
options []NamePolicyOption
want *NamePolicyEngine
wantErr bool
}
var tests = map[string]func(t *testing.T) test{
"fail/with-permitted-common-name": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithPermittedCommonNames("*"),
},
want: nil,
wantErr: true,
}
},
"fail/with-excluded-common-name": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithExcludedCommonNames(""),
},
want: nil,
wantErr: true,
}
},
"fail/with-permitted-dns-domains": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithPermittedDNSDomains("**.local"),
},
want: nil,
wantErr: true,
}
},
"fail/with-excluded-dns-domains": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithExcludedDNSDomains("**.local"),
},
want: nil,
wantErr: true,
}
},
"fail/with-permitted-cidrs": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithPermittedCIDRs("127.0.0.1//24"),
},
want: nil,
wantErr: true,
}
},
"fail/with-excluded-cidrs": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithExcludedCIDRs("127.0.0.1//24"),
},
want: nil,
wantErr: true,
}
},
"fail/with-permitted-ipsOrCIDRs-cidr": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithPermittedIPsOrCIDRs("127.0.0.1//24"),
},
want: nil,
wantErr: true,
}
},
"fail/with-permitted-ipsOrCIDRs-ip": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithPermittedIPsOrCIDRs("127.0.0:1"),
},
want: nil,
wantErr: true,
}
},
"fail/with-excluded-ipsOrCIDRs-cidr": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithExcludedIPsOrCIDRs("127.0.0.1//24"),
},
want: nil,
wantErr: true,
}
},
"fail/with-excluded-ipsOrCIDRs-ip": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithExcludedIPsOrCIDRs("127.0.0:1"),
},
want: nil,
wantErr: true,
}
},
"fail/with-permitted-emails": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithPermittedEmailAddresses("*.local"),
},
want: nil,
wantErr: true,
}
},
"fail/with-excluded-emails": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithExcludedEmailAddresses("*.local"),
},
want: nil,
wantErr: true,
}
},
"fail/with-permitted-uris": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithPermittedURIDomains("**.local"),
},
want: nil,
wantErr: true,
}
},
"fail/with-excluded-uris": func(t *testing.T) test {
return test{
options: []NamePolicyOption{
WithExcludedURIDomains("**.local"),
},
want: nil,
wantErr: true,
}
},
"ok/default": func(t *testing.T) test {
return test{
options: []NamePolicyOption{},
want: &NamePolicyEngine{},
wantErr: false,
}
},
"ok/subject-verification": func(t *testing.T) test {
options := []NamePolicyOption{
WithSubjectCommonNameVerification(),
}
return test{
options: options,
want: &NamePolicyEngine{
verifySubjectCommonName: true,
},
wantErr: false,
}
},
"ok/literal-wildcards": func(t *testing.T) test {
options := []NamePolicyOption{
WithAllowLiteralWildcardNames(),
}
return test{
options: options,
want: &NamePolicyEngine{
allowLiteralWildcardNames: true,
},
wantErr: false,
}
},
"ok/with-permitted-dns-wildcard-domains": func(t *testing.T) test {
options := []NamePolicyOption{
WithPermittedDNSDomains("*.local", "*.example.com"),
}
return test{
options: options,
want: &NamePolicyEngine{
permittedDNSDomains: []string{".local", ".example.com"},
numberOfDNSDomainConstraints: 2,
totalNumberOfPermittedConstraints: 2,
totalNumberOfConstraints: 2,
},
wantErr: false,
}
},
"ok/with-excluded-dns-domains": func(t *testing.T) test {
options := []NamePolicyOption{
WithExcludedDNSDomains("*.local", "*.example.com"),
}
return test{
options: options,
want: &NamePolicyEngine{
excludedDNSDomains: []string{".local", ".example.com"},
numberOfDNSDomainConstraints: 2,
totalNumberOfExcludedConstraints: 2,
totalNumberOfConstraints: 2,
},
wantErr: false,
}
},
"ok/with-permitted-ip-ranges": func(t *testing.T) test {
_, nw1, err := net.ParseCIDR("127.0.0.1/24")
assert.NoError(t, err)
_, nw2, err := net.ParseCIDR("192.168.0.1/24")
assert.NoError(t, err)
options := []NamePolicyOption{
WithPermittedIPRanges(nw1, nw2),
}
return test{
options: options,
want: &NamePolicyEngine{
permittedIPRanges: []*net.IPNet{
nw1, nw2,
},
numberOfIPRangeConstraints: 2,
totalNumberOfPermittedConstraints: 2,
totalNumberOfConstraints: 2,
},
wantErr: false,
}
},
"ok/with-excluded-ip-ranges": func(t *testing.T) test {
_, nw1, err := net.ParseCIDR("127.0.0.1/24")
assert.NoError(t, err)
_, nw2, err := net.ParseCIDR("192.168.0.1/24")
assert.NoError(t, err)
options := []NamePolicyOption{
WithExcludedIPRanges(nw1, nw2),
}
return test{
options: options,
want: &NamePolicyEngine{
excludedIPRanges: []*net.IPNet{
nw1, nw2,
},
numberOfIPRangeConstraints: 2,
totalNumberOfExcludedConstraints: 2,
totalNumberOfConstraints: 2,
},
wantErr: false,
}
},
"ok/with-permitted-cidrs": func(t *testing.T) test {
_, nw1, err := net.ParseCIDR("127.0.0.1/24")
assert.NoError(t, err)
_, nw2, err := net.ParseCIDR("192.168.0.1/24")
assert.NoError(t, err)
options := []NamePolicyOption{
WithPermittedCIDRs("127.0.0.1/24", "192.168.0.1/24"),
}
return test{
options: options,
want: &NamePolicyEngine{
permittedIPRanges: []*net.IPNet{
nw1, nw2,
},
numberOfIPRangeConstraints: 2,
totalNumberOfPermittedConstraints: 2,
totalNumberOfConstraints: 2,
},
wantErr: false,
}
},
"ok/with-excluded-cidrs": func(t *testing.T) test {
_, nw1, err := net.ParseCIDR("127.0.0.1/24")
assert.NoError(t, err)
_, nw2, err := net.ParseCIDR("192.168.0.1/24")
assert.NoError(t, err)
options := []NamePolicyOption{
WithExcludedCIDRs("127.0.0.1/24", "192.168.0.1/24"),
}
return test{
options: options,
want: &NamePolicyEngine{
excludedIPRanges: []*net.IPNet{
nw1, nw2,
},
numberOfIPRangeConstraints: 2,
totalNumberOfExcludedConstraints: 2,
totalNumberOfConstraints: 2,
},
wantErr: false,
}
},
"ok/with-permitted-ipsOrCIDRs-cidr": func(t *testing.T) test {
_, nw1, err := net.ParseCIDR("127.0.0.1/24")
assert.NoError(t, err)
_, nw2, err := net.ParseCIDR("192.168.0.31/32")
assert.NoError(t, err)
_, nw3, err := net.ParseCIDR("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128")
assert.NoError(t, err)
options := []NamePolicyOption{
WithPermittedIPsOrCIDRs("127.0.0.1/24", "192.168.0.31", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
}
return test{
options: options,
want: &NamePolicyEngine{
permittedIPRanges: []*net.IPNet{
nw1, nw2, nw3,
},
numberOfIPRangeConstraints: 3,
totalNumberOfPermittedConstraints: 3,
totalNumberOfConstraints: 3,
},
wantErr: false,
}
},
"ok/with-excluded-ipsOrCIDRs-cidr": func(t *testing.T) test {
_, nw1, err := net.ParseCIDR("127.0.0.1/24")
assert.NoError(t, err)
_, nw2, err := net.ParseCIDR("192.168.0.31/32")
assert.NoError(t, err)
_, nw3, err := net.ParseCIDR("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128")
assert.NoError(t, err)
options := []NamePolicyOption{
WithExcludedIPsOrCIDRs("127.0.0.1/24", "192.168.0.31", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
}
return test{
options: options,
want: &NamePolicyEngine{
excludedIPRanges: []*net.IPNet{
nw1, nw2, nw3,
},
numberOfIPRangeConstraints: 3,
totalNumberOfExcludedConstraints: 3,
totalNumberOfConstraints: 3,
},
wantErr: false,
}
},
"ok/with-permitted-emails": func(t *testing.T) test {
options := []NamePolicyOption{
WithPermittedEmailAddresses("mail@local", "@example.com"),
}
return test{
options: options,
want: &NamePolicyEngine{
permittedEmailAddresses: []string{"mail@local", "example.com"},
numberOfEmailAddressConstraints: 2,
totalNumberOfPermittedConstraints: 2,
totalNumberOfConstraints: 2,
},
wantErr: false,
}
},
"ok/with-excluded-emails": func(t *testing.T) test {
options := []NamePolicyOption{
WithExcludedEmailAddresses("mail@local", "@example.com"),
}
return test{
options: options,
want: &NamePolicyEngine{
excludedEmailAddresses: []string{"mail@local", "example.com"},
numberOfEmailAddressConstraints: 2,
totalNumberOfExcludedConstraints: 2,
totalNumberOfConstraints: 2,
},
wantErr: false,
}
},
"ok/with-permitted-uris": func(t *testing.T) test {
options := []NamePolicyOption{
WithPermittedURIDomains("host.local", "*.example.com"),
}
return test{
options: options,
want: &NamePolicyEngine{
permittedURIDomains: []string{"host.local", ".example.com"},
numberOfURIDomainConstraints: 2,
totalNumberOfPermittedConstraints: 2,
totalNumberOfConstraints: 2,
},
wantErr: false,
}
},
"ok/with-excluded-uris": func(t *testing.T) test {
options := []NamePolicyOption{
WithExcludedURIDomains("host.local", "*.example.com"),
}
return test{
options: options,
want: &NamePolicyEngine{
excludedURIDomains: []string{"host.local", ".example.com"},
numberOfURIDomainConstraints: 2,
totalNumberOfExcludedConstraints: 2,
totalNumberOfConstraints: 2,
},
wantErr: false,
}
},
"ok/with-permitted-principals": func(t *testing.T) test {
options := []NamePolicyOption{
WithPermittedPrincipals("root", "ops"),
}
return test{
options: options,
want: &NamePolicyEngine{
permittedPrincipals: []string{"root", "ops"},
numberOfPrincipalConstraints: 2,
totalNumberOfPermittedConstraints: 2,
totalNumberOfConstraints: 2,
},
wantErr: false,
}
},
"ok/with-excluded-principals": func(t *testing.T) test {
options := []NamePolicyOption{
WithExcludedPrincipals("root", "ops"),
}
return test{
options: options,
want: &NamePolicyEngine{
excludedPrincipals: []string{"root", "ops"},
numberOfPrincipalConstraints: 2,
totalNumberOfExcludedConstraints: 2,
totalNumberOfConstraints: 2,
},
wantErr: false,
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
got, err := New(tc.options...)
if (err != nil) != tc.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tc.wantErr)
return
}
if !cmp.Equal(tc.want, got, cmp.AllowUnexported(NamePolicyEngine{})) {
t.Errorf("New() diff =\n %s", cmp.Diff(tc.want, got, cmp.AllowUnexported(NamePolicyEngine{})))
}
})
}
}

@ -0,0 +1,9 @@
package policy
import (
"golang.org/x/crypto/ssh"
)
type SSHNamePolicyEngine interface {
IsSSHCertificateAllowed(cert *ssh.Certificate) error
}

@ -0,0 +1,647 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// The code in this file is an adapted version of the code in
// https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go
package policy
import (
"bytes"
"fmt"
"net"
"net/url"
"reflect"
"strings"
"golang.org/x/net/idna"
"go.step.sm/crypto/x509util"
)
// validateNames verifies that all names are allowed.
func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailAddresses []string, uris []*url.URL, principals []string) error {
// nothing to compare against; return early
if e.totalNumberOfConstraints == 0 {
return nil
}
// TODO: set limit on total of all names validated? In x509 there's a limit on the number of comparisons
// that protects the CA from a DoS (i.e. many heavy comparisons). The x509 implementation takes
// this number as a total of all checks and keeps a (pointer to a) counter of the number of checks
// executed so far.
// TODO: gather all errors, or return early? Currently we return early on the first wrong name; check might fail for multiple names.
// Perhaps make that an option?
for _, dns := range dnsNames {
// if there are DNS names to check, no DNS constraints set, but there are other permitted constraints,
// then return error, because DNS should be explicitly configured to be allowed in that case. In case there are
// (other) excluded constraints, we'll allow a DNS (implicit allow; currently).
if e.numberOfDNSDomainConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 {
return &NamePolicyError{
Reason: NotAllowed,
NameType: DNSNameType,
Name: dns,
detail: fmt.Sprintf("dns %q is not explicitly permitted by any constraint", dns),
}
}
didCutWildcard := false
parsedDNS := dns
if strings.HasPrefix(parsedDNS, "*.") {
parsedDNS = parsedDNS[1:]
didCutWildcard = true
}
// TODO(hs): fix this above; we need separate rule for Subject Common Name?
parsedDNS, err := idna.Lookup.ToASCII(parsedDNS)
if err != nil {
return &NamePolicyError{
Reason: CannotParseDomain,
NameType: DNSNameType,
Name: dns,
detail: fmt.Sprintf("dns %q cannot be converted to ASCII", dns),
}
}
if didCutWildcard {
parsedDNS = "*" + parsedDNS
}
if _, ok := domainToReverseLabels(parsedDNS); !ok { // TODO(hs): this also fails with spaces
return &NamePolicyError{
Reason: CannotParseDomain,
NameType: DNSNameType,
Name: dns,
detail: fmt.Sprintf("cannot parse dns %q", dns),
}
}
if err := checkNameConstraints(DNSNameType, dns, parsedDNS,
func(parsedName, constraint interface{}) (bool, error) {
return e.matchDomainConstraint(parsedName.(string), constraint.(string))
}, e.permittedDNSDomains, e.excludedDNSDomains); err != nil {
return err
}
}
for _, ip := range ips {
if e.numberOfIPRangeConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 {
return &NamePolicyError{
Reason: NotAllowed,
NameType: IPNameType,
Name: ip.String(),
detail: fmt.Sprintf("ip %q is not explicitly permitted by any constraint", ip.String()),
}
}
if err := checkNameConstraints(IPNameType, ip.String(), ip,
func(parsedName, constraint interface{}) (bool, error) {
return matchIPConstraint(parsedName.(net.IP), constraint.(*net.IPNet))
}, e.permittedIPRanges, e.excludedIPRanges); err != nil {
return err
}
}
for _, email := range emailAddresses {
if e.numberOfEmailAddressConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 {
return &NamePolicyError{
Reason: NotAllowed,
NameType: EmailNameType,
Name: email,
detail: fmt.Sprintf("email %q is not explicitly permitted by any constraint", email),
}
}
mailbox, ok := parseRFC2821Mailbox(email)
if !ok {
return &NamePolicyError{
Reason: CannotParseRFC822Name,
NameType: EmailNameType,
Name: email,
detail: fmt.Sprintf("invalid rfc822Name %q", mailbox),
}
}
// According to RFC 5280, section 7.5, emails are considered to match if the local part is
// an exact match and the host (domain) part matches the ASCII representation (case-insensitive):
// https://datatracker.ietf.org/doc/html/rfc5280#section-7.5
domainASCII, err := idna.ToASCII(mailbox.domain)
if err != nil {
return &NamePolicyError{
Reason: CannotParseDomain,
NameType: EmailNameType,
Name: email,
detail: fmt.Errorf("cannot parse email domain %q: %w", email, err).Error(),
}
}
mailbox.domain = domainASCII
if err := checkNameConstraints(EmailNameType, email, mailbox,
func(parsedName, constraint interface{}) (bool, error) {
return e.matchEmailConstraint(parsedName.(rfc2821Mailbox), constraint.(string))
}, e.permittedEmailAddresses, e.excludedEmailAddresses); err != nil {
return err
}
}
// TODO(hs): fix internationalization for URIs (IRIs)
for _, uri := range uris {
if e.numberOfURIDomainConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 {
return &NamePolicyError{
Reason: NotAllowed,
NameType: URINameType,
Name: uri.String(),
detail: fmt.Sprintf("uri %q is not explicitly permitted by any constraint", uri.String()),
}
}
// TODO(hs): ideally we'd like the uri.String() to be the original contents; now
// it's transformed into ASCII. Prevent that here?
if err := checkNameConstraints(URINameType, uri.String(), uri,
func(parsedName, constraint interface{}) (bool, error) {
return e.matchURIConstraint(parsedName.(*url.URL), constraint.(string))
}, e.permittedURIDomains, e.excludedURIDomains); err != nil {
return err
}
}
for _, principal := range principals {
if e.numberOfPrincipalConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 {
return &NamePolicyError{
Reason: NotAllowed,
NameType: PrincipalNameType,
Name: principal,
detail: fmt.Sprintf("username principal %q is not explicitly permitted by any constraint", principal),
}
}
// TODO: some validation? I.e. allowed characters?
if err := checkNameConstraints(PrincipalNameType, principal, principal,
func(parsedName, constraint interface{}) (bool, error) {
return matchPrincipalConstraint(parsedName.(string), constraint.(string))
}, e.permittedPrincipals, e.excludedPrincipals); err != nil {
return err
}
}
// if all checks out, all SANs are allowed
return nil
}
// validateCommonName verifies that the Subject Common Name is allowed
func (e *NamePolicyEngine) validateCommonName(commonName string) error {
// nothing to compare against; return early
if e.totalNumberOfConstraints == 0 {
return nil
}
// empty common names are not validated
if commonName == "" {
return nil
}
if e.numberOfCommonNameConstraints > 0 {
// Check the Common Name using its dedicated matcher if constraints have been
// configured. If no error is returned from matching, the Common Name was
// explicitly allowed and nil is returned immediately.
if err := checkNameConstraints(CNNameType, commonName, commonName,
func(parsedName, constraint interface{}) (bool, error) {
return matchCommonNameConstraint(parsedName.(string), constraint.(string))
}, e.permittedCommonNames, e.excludedCommonNames); err == nil {
return nil
}
}
// When an error was returned or when no constraints were configured for Common Names,
// the Common Name should be validated against the other types of constraints too,
// according to what type it is.
dnsNames, ips, emails, uris := x509util.SplitSANs([]string{commonName})
err := e.validateNames(dnsNames, ips, emails, uris, []string{})
if pe, ok := err.(*NamePolicyError); ok {
// override the name type with CN
pe.NameType = CNNameType
}
return err
}
// checkNameConstraints checks that a name, of type nameType is permitted.
// The argument parsedName contains the parsed form of name, suitable for passing
// to the match function.
func checkNameConstraints(
nameType NameType,
name string,
parsedName interface{},
match func(parsedName, constraint interface{}) (match bool, err error),
permitted, excluded interface{}) error {
excludedValue := reflect.ValueOf(excluded)
for i := 0; i < excludedValue.Len(); i++ {
constraint := excludedValue.Index(i).Interface()
match, err := match(parsedName, constraint)
if err != nil {
return &NamePolicyError{
Reason: CannotMatchNameToConstraint,
NameType: nameType,
Name: name,
detail: err.Error(),
}
}
if match {
return &NamePolicyError{
Reason: NotAllowed,
NameType: nameType,
Name: name,
detail: fmt.Sprintf("%s %q is excluded by constraint %q", nameType, name, constraint),
}
}
}
permittedValue := reflect.ValueOf(permitted)
ok := true
for i := 0; i < permittedValue.Len(); i++ {
constraint := permittedValue.Index(i).Interface()
var err error
if ok, err = match(parsedName, constraint); err != nil {
return &NamePolicyError{
Reason: CannotMatchNameToConstraint,
NameType: nameType,
Name: name,
detail: err.Error(),
}
}
if ok {
break
}
}
if !ok {
return &NamePolicyError{
Reason: NotAllowed,
NameType: nameType,
Name: name,
detail: fmt.Sprintf("%s %q is not permitted by any constraint", nameType, name),
}
}
return nil
}
// domainToReverseLabels converts a textual domain name like foo.example.com to
// the list of labels in reverse order, e.g. ["com", "example", "foo"].
func domainToReverseLabels(domain string) (reverseLabels []string, ok bool) {
for len(domain) > 0 {
if i := strings.LastIndexByte(domain, '.'); i == -1 {
reverseLabels = append(reverseLabels, domain)
domain = ""
} else {
reverseLabels = append(reverseLabels, domain[i+1:])
domain = domain[:i]
}
}
if len(reverseLabels) > 0 && reverseLabels[0] == "" {
// An empty label at the end indicates an absolute value.
return nil, false
}
for _, label := range reverseLabels {
if label == "" {
// Empty labels are otherwise invalid.
return nil, false
}
for _, c := range label {
if c < 33 || c > 126 {
// Invalid character.
return nil, false
}
}
}
return reverseLabels, true
}
// rfc2821Mailbox represents a “mailbox” (which is an email address to most
// people) by breaking it into the “local” (i.e. before the '@') and “domain”
// parts.
type rfc2821Mailbox struct {
local, domain string
}
// parseRFC2821Mailbox parses an email address into local and domain parts,
// based on the ABNF for a “Mailbox” from RFC 2821. According to RFC 5280,
// Section 4.2.1.6 that's correct for an rfc822Name from a certificate: “The
// format of an rfc822Name is a "Mailbox" as defined in RFC 2821, Section 4.1.2”.
func parseRFC2821Mailbox(in string) (mailbox rfc2821Mailbox, ok bool) {
if in == "" {
return mailbox, false
}
localPartBytes := make([]byte, 0, len(in)/2)
if in[0] == '"' {
// Quoted-string = DQUOTE *qcontent DQUOTE
// non-whitespace-control = %d1-8 / %d11 / %d12 / %d14-31 / %d127
// qcontent = qtext / quoted-pair
// qtext = non-whitespace-control /
// %d33 / %d35-91 / %d93-126
// quoted-pair = ("\" text) / obs-qp
// text = %d1-9 / %d11 / %d12 / %d14-127 / obs-text
//
// (Names beginning with “obs-” are the obsolete syntax from RFC 2822,
// Section 4. Since it has been 16 years, we no longer accept that.)
in = in[1:]
QuotedString:
for {
if in == "" {
return mailbox, false
}
c := in[0]
in = in[1:]
switch {
case c == '"':
break QuotedString
case c == '\\':
// quoted-pair
if in == "" {
return mailbox, false
}
if in[0] == 11 ||
in[0] == 12 ||
(1 <= in[0] && in[0] <= 9) ||
(14 <= in[0] && in[0] <= 127) {
localPartBytes = append(localPartBytes, in[0])
in = in[1:]
} else {
return mailbox, false
}
case c == 11 ||
c == 12 ||
// Space (char 32) is not allowed based on the
// BNF, but RFC 3696 gives an example that
// assumes that it is. Several “verified”
// errata continue to argue about this point.
// We choose to accept it.
c == 32 ||
c == 33 ||
c == 127 ||
(1 <= c && c <= 8) ||
(14 <= c && c <= 31) ||
(35 <= c && c <= 91) ||
(93 <= c && c <= 126):
// qtext
localPartBytes = append(localPartBytes, c)
default:
return mailbox, false
}
}
} else {
// Atom ("." Atom)*
NextChar:
for len(in) > 0 {
// atext from RFC 2822, Section 3.2.4
c := in[0]
switch {
case c == '\\':
// Examples given in RFC 3696 suggest that
// escaped characters can appear outside of a
// quoted string. Several “verified” errata
// continue to argue the point. We choose to
// accept it.
in = in[1:]
if in == "" {
return mailbox, false
}
fallthrough
case ('0' <= c && c <= '9') ||
('a' <= c && c <= 'z') ||
('A' <= c && c <= 'Z') ||
c == '!' || c == '#' || c == '$' || c == '%' ||
c == '&' || c == '\'' || c == '*' || c == '+' ||
c == '-' || c == '/' || c == '=' || c == '?' ||
c == '^' || c == '_' || c == '`' || c == '{' ||
c == '|' || c == '}' || c == '~' || c == '.':
localPartBytes = append(localPartBytes, in[0])
in = in[1:]
default:
break NextChar
}
}
if len(localPartBytes) == 0 {
return mailbox, false
}
// From RFC 3696, Section 3:
// “period (".") may also appear, but may not be used to start
// or end the local part, nor may two or more consecutive
// periods appear.”
twoDots := []byte{'.', '.'}
if localPartBytes[0] == '.' ||
localPartBytes[len(localPartBytes)-1] == '.' ||
bytes.Contains(localPartBytes, twoDots) {
return mailbox, false
}
}
if in == "" || in[0] != '@' {
return mailbox, false
}
in = in[1:]
// The RFC species a format for domains, but that's known to be
// violated in practice so we accept that anything after an '@' is the
// domain part.
if _, ok := domainToReverseLabels(in); !ok {
return mailbox, false
}
mailbox.local = string(localPartBytes)
mailbox.domain = in
return mailbox, true
}
// matchDomainConstraint matches a domain against the given constraint
func (e *NamePolicyEngine) matchDomainConstraint(domain, constraint string) (bool, error) {
// The meaning of zero length constraints is not specified, but this
// code follows NSS and accepts them as matching everything.
if constraint == "" {
return true, nil
}
// A single whitespace seems to be considered a valid domain, but we don't allow it.
if domain == " " {
return false, nil
}
// Block domains that start with just a period
if domain[0] == '.' {
return false, nil
}
// Block wildcard domains that don't start with exactly "*." (i.e. double wildcards and such)
if domain[0] == '*' && domain[1] != '.' {
return false, nil
}
// Check if the domain starts with a wildcard and return early if not allowed
if strings.HasPrefix(domain, "*.") && !e.allowLiteralWildcardNames {
return false, nil
}
// Only allow asterisk at the start of the domain; we don't allow them as part of a domain label or as a (sub)domain label (currently)
if strings.LastIndex(domain, "*") > 0 {
return false, nil
}
// Don't allow constraints with empty labels in any position
if strings.Contains(constraint, "..") {
return false, nil
}
domainLabels, ok := domainToReverseLabels(domain)
if !ok {
return false, fmt.Errorf("cannot parse domain %q", domain)
}
// RFC 5280 says that a leading period in a domain name means that at
// least one label must be prepended, but only for URI and email
// constraints, not DNS constraints. The code also supports that
// behavior for DNS constraints. In our adaptation of the original
// Go stdlib x509 Name Constraint implementation we look for exactly
// one subdomain, currently.
mustHaveSubdomains := false
if constraint[0] == '.' {
mustHaveSubdomains = true
constraint = constraint[1:]
}
constraintLabels, ok := domainToReverseLabels(constraint)
if !ok {
return false, fmt.Errorf("cannot parse domain constraint %q", constraint)
}
expectedNumberOfLabels := len(constraintLabels)
if mustHaveSubdomains {
// we expect exactly one more label if it starts with the "canonical" x509 "wildcard": "."
// in the future we could extend this to support multiple additional labels and/or more
// complex matching.
expectedNumberOfLabels++
}
if len(domainLabels) != expectedNumberOfLabels {
return false, nil
}
for i, constraintLabel := range constraintLabels {
if !strings.EqualFold(constraintLabel, domainLabels[i]) {
return false, nil
}
}
return true, nil
}
// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go
func matchIPConstraint(ip net.IP, constraint *net.IPNet) (bool, error) {
// TODO(hs): this is code from Go library, but I got some unexpected result:
// with permitted net 127.0.0.0/24, 127.0.0.1 is NOT allowed. When parsing 127.0.0.1 as net.IP
// which is in the IPAddresses slice, the underlying length is 16. The contraint.IP has a length
// of 4 instead. I currently don't believe that this is a bug in Go now, but why is it like that?
// Is there a difference because we're not operating on a sans []string slice? Or is the Go
// implementation stricter regarding IPv4 vs. IPv6? I've been bitten by some unfortunate differences
// between the two before (i.e. IPv4 in IPv6; IP SANS in ACME)
// if len(ip) != len(constraint.IP) {
// return false, nil
// }
// for i := range ip {
// if mask := constraint.Mask[i]; ip[i]&mask != constraint.IP[i]&mask {
// return false, nil
// }
// }
contained := constraint.Contains(ip) // TODO(hs): validate that this is the correct behavior; also check IPv4-in-IPv6 (again)
return contained, nil
}
// SOURCE: https://cs.opensource.google/go/go/+/refs/tags/go1.17.5:src/crypto/x509/verify.go
func (e *NamePolicyEngine) matchEmailConstraint(mailbox rfc2821Mailbox, constraint string) (bool, error) {
if strings.Contains(constraint, "@") {
constraintMailbox, ok := parseRFC2821Mailbox(constraint)
if !ok {
return false, fmt.Errorf("cannot parse constraint %q", constraint)
}
return mailbox.local == constraintMailbox.local && strings.EqualFold(mailbox.domain, constraintMailbox.domain), nil
}
// Otherwise the constraint is like a DNS constraint of the domain part
// of the mailbox.
return e.matchDomainConstraint(mailbox.domain, constraint)
}
// matchURIConstraint matches an URL against a constraint
func (e *NamePolicyEngine) matchURIConstraint(uri *url.URL, constraint string) (bool, error) {
// From RFC 5280, Section 4.2.1.10:
// “a uniformResourceIdentifier that does not include an authority
// component with a host name specified as a fully qualified domain
// name (e.g., if the URI either does not include an authority
// component or includes an authority component in which the host name
// is specified as an IP address), then the application MUST reject the
// certificate.”
host := uri.Host
if host == "" {
return false, fmt.Errorf("URI with empty host (%q) cannot be matched against constraints", uri.String())
}
// Block hosts with the wildcard character; no exceptions, also not when wildcards allowed.
if strings.Contains(host, "*") {
return false, fmt.Errorf("URI host %q cannot contain asterisk", uri.String())
}
if strings.Contains(host, ":") && !strings.HasSuffix(host, "]") {
var err error
host, _, err = net.SplitHostPort(host)
if err != nil {
return false, err
}
}
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") ||
net.ParseIP(host) != nil {
return false, fmt.Errorf("URI with IP %q cannot be matched against constraints", uri.String())
}
// TODO(hs): add checks for scheme, path, etc.; either here, or in a different constraint matcher (to keep this one simple)
return e.matchDomainConstraint(host, constraint)
}
// matchPrincipalConstraint performs a string literal equality check against a constraint.
func matchPrincipalConstraint(principal, constraint string) (bool, error) {
// allow any plain principal when wildcard constraint is used
if constraint == "*" {
return true, nil
}
return strings.EqualFold(principal, constraint), nil
}
// matchCommonNameConstraint performs a string literal equality check against constraint.
func matchCommonNameConstraint(commonName, constraint string) (bool, error) {
// wildcard constraint is (currently) not supported for common names
if constraint == "*" {
return false, nil
}
return strings.EqualFold(commonName, constraint), nil
}

@ -0,0 +1,14 @@
package policy
import (
"crypto/x509"
"net"
)
type X509NamePolicyEngine interface {
IsX509CertificateAllowed(cert *x509.Certificate) error
IsX509CertificateRequestAllowed(csr *x509.CertificateRequest) error
AreSANsAllowed(sans []string) error
IsDNSAllowed(dns string) error
IsIPAllowed(ip net.IP) error
}
Loading…
Cancel
Save