Merge branch 'master' into ahmet2mir-feat/vault

pull/894/head
Mariano Cano 2 years ago committed by GitHub
commit fe9c3cf753
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -6,15 +6,17 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased - 0.18.3] - DATE
### Added
- Added support for renew after expiry using the claim `allowRenewAfterExpiry`.
- Added support for certificate renewals after expiry using the claim `allowRenewalAfterExpiry`.
- Added support for `extraNames` in X.509 templates.
- Added RA support using a Vault instance as the CA.
- Added support for automatic configuration of linked RAs.
### Changed
- Made SCEP CA URL paths dynamic
- Support two latest versions of Go (1.17, 1.18)
### Deprecated
### Removed
### Fixed
- Fixed admin credentials on RAs.
### Security
## [0.18.2] - 2022-03-01

@ -54,7 +54,7 @@ Setting up a *public key infrastructure* (PKI) is out of reach for many small te
- [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with automated enrollment, renewal, and passive revocation
- Capable of high availability (HA) deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries
- Can operate as [an online intermediate CA for an existing root CA](https://smallstep.com/docs/tutorials/intermediate-ca-new-ca)
- [Badger, BoltDB, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases)
- [Badger, BoltDB, Postgres, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases)
### ⚙️ Many ways to automate

@ -130,22 +130,24 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
// According to "rfc7519 JSON Web Token" acceptable skew should be no
// more than a few minutes.
if err := claims.ValidateWithLeeway(jose.Expected{
Issuer: prov.GetName(),
Time: time.Now().UTC(),
Time: time.Now().UTC(),
}, time.Minute); err != nil {
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "x5c.authorizeToken; invalid x5c claims")
}
// validate audience: path matches the current path
if r.URL.Path != claims.Audience[0] {
return nil, admin.NewError(admin.ErrorUnauthorizedType,
"x5c.authorizeToken; x5c token has invalid audience "+
"claim (aud); expected %s, but got %s", r.URL.Path, claims.Audience)
if !matchesAudience(claims.Audience, a.config.Audience(r.URL.Path)) {
return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token has invalid audience claim (aud)")
}
// validate issuer: old versions used the provisioner name, new version uses
// 'step-admin-client/1.0'
if claims.Issuer != "step-admin-client/1.0" && claims.Issuer != prov.GetName() {
return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token has invalid issuer claim (iss)")
}
if claims.Subject == "" {
return nil, admin.NewError(admin.ErrorUnauthorizedType,
"x5c.authorizeToken; x5c token subject cannot be empty")
return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token subject cannot be empty")
}
var (
@ -156,7 +158,7 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
adminSANs := append([]string{leaf.Subject.CommonName}, leaf.DNSNames...)
adminSANs = append(adminSANs, leaf.EmailAddresses...)
for _, san := range adminSANs {
if adm, ok = a.LoadAdminBySubProv(san, claims.Issuer); ok {
if adm, ok = a.LoadAdminBySubProv(san, prov.GetName()); ok {
adminFound = true
break
}
@ -285,9 +287,16 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
if isRevoked {
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
}
p, ok := a.provisioners.LoadByCertificate(cert)
if !ok {
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
p, err := a.LoadProvisionerByCertificate(cert)
if err != nil {
var ok bool
// For backward compatibility this method will also succeed if the
// certificate does not have a provisioner extension. LoadByCertificate
// returns the noop provisioner if this happens, and it allows
// certificate renewals.
if p, ok = a.provisioners.LoadByCertificate(cert); !ok {
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
}
}
if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
@ -386,8 +395,8 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token"))
}
p, ok := a.provisioners.LoadByCertificate(leaf)
if !ok {
p, err := a.LoadProvisionerByCertificate(leaf)
if err != nil {
return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate")
}
if err := a.UseToken(ott, p); err != nil {
@ -395,7 +404,6 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
}
if err := claims.ValidateWithLeeway(jose.Expected{
Issuer: p.GetName(),
Subject: leaf.Subject.CommonName,
Time: time.Now().UTC(),
}, time.Minute); err != nil {
@ -420,6 +428,12 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
}
// validate issuer: old versions used the provisioner name, new version uses
// 'step-ca-client/1.0'
if claims.Issuer != "step-ca-client/1.0" && claims.Issuer != p.GetName() {
return nil, admin.NewError(admin.ErrorUnauthorizedType, "error validating renew token: invalid issuer claim (iss)")
}
return leaf, nil
}

@ -847,6 +847,29 @@ func TestAuthority_authorizeRenew(t *testing.T) {
cert: fooCrt,
}
},
"ok/from db": func(t *testing.T) *authorizeTest {
a := testAuthority(t)
a.db = &db.MockAuthDB{
MIsRevoked: func(key string) (bool, error) {
return false, nil
},
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
p, ok := a.provisioners.LoadByName("step-cli")
if !ok {
t.Fatal("provisioner step-cli not found")
}
return &db.CertificateData{
Provisioner: &db.ProvisionerData{
ID: p.GetID(),
},
}, nil
},
}
return &authorizeTest{
auth: a,
cert: fooCrt,
}
},
}
for name, genTestCase := range tests {
@ -1381,7 +1404,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
t1, c1 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1400,7 +1423,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
t2, c2 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now),
@ -1417,12 +1440,31 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
})
return nil
}))
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
t3, c3 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
@ -1439,7 +1481,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1477,7 +1519,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badSubject, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "bad-subject",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1496,7 +1538,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1515,7 +1557,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1534,7 +1576,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)),
@ -1554,7 +1596,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badAudience, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1584,6 +1626,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
}{
{"ok", a1, args{ctx, t1}, c1, false},
{"ok expired cert", a1, args{ctx, t2}, c2, false},
{"ok provisioner issuer", a1, args{ctx, t3}, c3, false},
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
{"fail token reuse", a1, args{ctx, t1}, nil, true},
{"fail token signature", a1, args{ctx, badSigner}, nil, true},

@ -26,27 +26,27 @@ var (
DefaultBackdate = time.Minute
// DefaultDisableRenewal disables renewals per provisioner.
DefaultDisableRenewal = false
// DefaultAllowRenewAfterExpiry allows renewals even if the certificate is
// DefaultAllowRenewalAfterExpiry allows renewals even if the certificate is
// expired.
DefaultAllowRenewAfterExpiry = false
DefaultAllowRenewalAfterExpiry = false
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally
// for all provisioners.
DefaultEnableSSHCA = false
// GlobalProvisionerClaims default claims for the Authority. Can be overridden
// by provisioner specific claims.
GlobalProvisionerClaims = provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &DefaultEnableSSHCA,
DisableRenewal: &DefaultDisableRenewal,
AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry,
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &DefaultEnableSSHCA,
DisableRenewal: &DefaultDisableRenewal,
AllowRenewalAfterExpiry: &DefaultAllowRenewalAfterExpiry,
}
)
@ -308,6 +308,18 @@ func (c *Config) GetAudiences() provisioner.Audiences {
return audiences
}
// Audience returns the list of audiences for a given path.
func (c *Config) Audience(path string) []string {
audiences := make([]string, len(c.DNSNames)+1)
for i, name := range c.DNSNames {
hostname := toHostname(name)
audiences[i] = "https://" + hostname + path
}
// For backward compatibility
audiences[len(c.DNSNames)] = path
return audiences
}
func toHostname(name string) string {
// ensure an IPv6 address is represented with square brackets when used as hostname
if ip := net.ParseIP(name); ip != nil && ip.To4() == nil {

@ -2,6 +2,7 @@ package config
import (
"fmt"
"reflect"
"testing"
"github.com/pkg/errors"
@ -317,3 +318,38 @@ func Test_toHostname(t *testing.T) {
})
}
}
func TestConfig_Audience(t *testing.T) {
type fields struct {
DNSNames []string
}
type args struct {
path string
}
tests := []struct {
name string
fields fields
args args
want []string
}{
{"ok", fields{[]string{
"ca", "ca.example.com", "127.0.0.1", "::1",
}}, args{"/path"}, []string{
"https://ca/path",
"https://ca.example.com/path",
"https://127.0.0.1/path",
"https://[::1]/path",
"/path",
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Config{
DNSNames: tt.fields.DNSNames,
}
if got := c.Audience(tt.args.path); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Config.Audience() = %v, want %v", got, tt.want)
}
})
}
}

@ -235,6 +235,28 @@ func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error {
return errors.Wrap(err, "error deleting admin")
}
func (c *linkedCaClient) GetCertificateData(serial string) (*db.CertificateData, error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
resp, err := c.client.GetCertificate(ctx, &linkedca.GetCertificateRequest{
Serial: serial,
})
if err != nil {
return nil, err
}
var pd *db.ProvisionerData
if p := resp.Provisioner; p != nil {
pd = &db.ProvisionerData{
ID: p.Id, Name: p.Name, Type: p.Type.String(),
}
}
return &db.CertificateData{
Provisioner: pd,
}, nil
}
func (c *linkedCaClient) StoreCertificateChain(prov provisioner.Interface, fullchain ...*x509.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()

@ -24,8 +24,8 @@ type Claims struct {
EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
// Renewal properties
DisableRenewal *bool `json:"disableRenewal,omitempty"`
AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"`
DisableRenewal *bool `json:"disableRenewal,omitempty"`
AllowRenewalAfterExpiry *bool `json:"allowRenewalAfterExpiry,omitempty"`
}
// Claimer is the type that controls claims. It provides an interface around the
@ -44,22 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
// Claims returns the merge of the inner and global claims.
func (c *Claimer) Claims() Claims {
disableRenewal := c.IsDisableRenewal()
allowRenewAfterExpiry := c.AllowRenewAfterExpiry()
allowRenewalAfterExpiry := c.AllowRenewalAfterExpiry()
enableSSHCA := c.IsSSHCAEnabled()
return Claims{
MinTLSDur: &Duration{c.MinTLSCertDuration()},
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
EnableSSHCA: &enableSSHCA,
DisableRenewal: &disableRenewal,
AllowRenewAfterExpiry: &allowRenewAfterExpiry,
MinTLSDur: &Duration{c.MinTLSCertDuration()},
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
EnableSSHCA: &enableSSHCA,
DisableRenewal: &disableRenewal,
AllowRenewalAfterExpiry: &allowRenewalAfterExpiry,
}
}
@ -109,14 +109,14 @@ func (c *Claimer) IsDisableRenewal() bool {
return *c.claims.DisableRenewal
}
// AllowRenewAfterExpiry returns if the renewal flow is authorized if the
// AllowRenewalAfterExpiry returns if the renewal flow is authorized if the
// certificate is expired. If the property is not set within the provisioner
// then the global value from the authority configuration will be used.
func (c *Claimer) AllowRenewAfterExpiry() bool {
if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil {
return *c.global.AllowRenewAfterExpiry
func (c *Claimer) AllowRenewalAfterExpiry() bool {
if c.claims == nil || c.claims.AllowRenewalAfterExpiry == nil {
return *c.global.AllowRenewalAfterExpiry
}
return *c.claims.AllowRenewAfterExpiry
return *c.claims.AllowRenewalAfterExpiry
}
// DefaultSSHCertDuration returns the default SSH certificate duration for the

@ -124,7 +124,7 @@ func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certif
if now.Before(cert.NotBefore) {
return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano))
}
if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() {
if now.After(cert.NotAfter) && !p.Claimer.AllowRenewalAfterExpiry() {
return errs.Unauthorized("certificate has expired")
}
@ -144,7 +144,7 @@ func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Cert
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
return errs.Unauthorized("certificate is not yet valid")
}
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() {
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewalAfterExpiry() {
return errs.Unauthorized("certificate has expired")
}

@ -160,13 +160,13 @@ func TestController_AuthorizeRenew(t *testing.T) {
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return nil
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, false},
@ -231,13 +231,13 @@ func TestController_AuthorizeSSHRenew(t *testing.T) {
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return nil
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, false},
@ -296,7 +296,7 @@ func TestDefaultAuthorizeRenew(t *testing.T) {
}}, false},
{"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
Claimer: mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
@ -354,7 +354,7 @@ func TestDefaultAuthorizeSSHRenew(t *testing.T) {
}}, false},
{"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
Claimer: mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),

@ -9,7 +9,6 @@ import (
"strings"
"github.com/pkg/errors"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
)
@ -212,8 +211,6 @@ type Config struct {
Claims Claims
// Audiences are the audiences used in the default provisioner, (JWK).
Audiences Audiences
// DB is the interface to the authority DB client.
DB db.AuthDB
// SSHKeys are the root SSH public keys
SSHKeys *SSHKeys
// GetIdentityFunc is a function that returns an identity that will be

@ -24,22 +24,22 @@ import (
)
var (
defaultDisableRenewal = false
defaultAllowRenewAfterExpiry = false
defaultEnableSSHCA = true
globalProvisionerClaims = Claims{
MinTLSDur: &Duration{5 * time.Minute},
MaxTLSDur: &Duration{24 * time.Hour},
DefaultTLSDur: &Duration{24 * time.Hour},
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &defaultEnableSSHCA,
DisableRenewal: &defaultDisableRenewal,
AllowRenewAfterExpiry: &defaultAllowRenewAfterExpiry,
defaultDisableRenewal = false
defaultAllowRenewalAfterExpiry = false
defaultEnableSSHCA = true
globalProvisionerClaims = Claims{
MinTLSDur: &Duration{5 * time.Minute},
MaxTLSDur: &Duration{24 * time.Hour},
DefaultTLSDur: &Duration{24 * time.Hour},
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &defaultEnableSSHCA,
DisableRenewal: &defaultDisableRenewal,
AllowRenewalAfterExpiry: &defaultAllowRenewalAfterExpiry,
}
testAudiences = Audiences{
Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"},

@ -13,6 +13,7 @@ import (
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui"
@ -46,13 +47,43 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List,
func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) {
a.adminMutex.RLock()
defer a.adminMutex.RUnlock()
if p, err := a.unsafeLoadProvisionerFromDatabase(crt); err == nil {
return p, nil
}
return a.unsafeLoadProvisionerFromExtension(crt)
}
func (a *Authority) unsafeLoadProvisionerFromExtension(crt *x509.Certificate) (provisioner.Interface, error) {
p, ok := a.provisioners.LoadByCertificate(crt)
if !ok {
if !ok || p.GetType() == 0 {
return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate")
}
return p, nil
}
func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (provisioner.Interface, error) {
// certificateDataGetter is an interface that can be used to retrieve the
// provisioner from a db or a linked ca.
type certificateDataGetter interface {
GetCertificateData(string) (*db.CertificateData, error)
}
var err error
var data *db.CertificateData
if cdg, ok := a.adminDB.(certificateDataGetter); ok {
data, err = cdg.GetCertificateData(crt.SerialNumber.String())
} else if cdg, ok := a.db.(certificateDataGetter); ok {
data, err = cdg.GetCertificateData(crt.SerialNumber.String())
}
if err == nil && data != nil && data.Provisioner != nil {
if p, ok := a.provisioners.Load(data.Provisioner.ID); ok {
return p, nil
}
}
return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate")
}
// LoadProvisionerByToken returns an interface to the provisioner that
// provisioned the token.
func (a *Authority) LoadProvisionerByToken(token *jwt.JSONWebToken, claims *jwt.Claims) (provisioner.Interface, error) {
@ -103,7 +134,6 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner.
return provisioner.Config{
Claims: claimer.Claims(),
Audiences: a.config.GetAudiences(),
DB: a.db,
SSHKeys: &provisioner.SSHKeys{
UserKeys: sshKeys.UserKeys,
HostKeys: sshKeys.HostKeys,
@ -247,7 +277,7 @@ func (a *Authority) RemoveProvisioner(ctx context.Context, id string) error {
// CreateFirstProvisioner creates and stores the first provisioner when using
// admin database provisioner storage.
func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) {
func CreateFirstProvisioner(ctx context.Context, adminDB admin.DB, password string) (*linkedca.Provisioner, error) {
if password == "" {
pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one")
if err != nil {
@ -290,7 +320,7 @@ func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (
},
},
}
if err := db.CreateProvisioner(ctx, p); err != nil {
if err := adminDB.CreateProvisioner(ctx, p); err != nil {
return nil, admin.WrapErrorISE(err, "error creating provisioner")
}
return p, nil
@ -437,8 +467,8 @@ func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) {
}
pc := &provisioner.Claims{
DisableRenewal: &c.DisableRenewal,
AllowRenewAfterExpiry: &c.AllowRenewAfterExpiry,
DisableRenewal: &c.DisableRenewal,
AllowRenewalAfterExpiry: &c.AllowRenewalAfterExpiry,
}
var err error
@ -476,18 +506,18 @@ func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims {
}
disableRenewal := config.DefaultDisableRenewal
allowRenewAfterExpiry := config.DefaultAllowRenewAfterExpiry
allowRenewalAfterExpiry := config.DefaultAllowRenewalAfterExpiry
if c.DisableRenewal != nil {
disableRenewal = *c.DisableRenewal
}
if c.AllowRenewAfterExpiry != nil {
allowRenewAfterExpiry = *c.AllowRenewAfterExpiry
if c.AllowRenewalAfterExpiry != nil {
allowRenewalAfterExpiry = *c.AllowRenewalAfterExpiry
}
lc := &linkedca.Claims{
DisableRenewal: disableRenewal,
AllowRenewAfterExpiry: allowRenewAfterExpiry,
DisableRenewal: disableRenewal,
AllowRenewalAfterExpiry: allowRenewalAfterExpiry,
}
if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil {

@ -1,13 +1,21 @@
package authority
import (
"context"
"crypto/x509"
"errors"
"net/http"
"reflect"
"testing"
"time"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
)
func TestGetEncryptedKey(t *testing.T) {
@ -67,6 +75,15 @@ func TestGetEncryptedKey(t *testing.T) {
}
}
type mockAdminDB struct {
admin.MockDB
MGetCertificateData func(string) (*db.CertificateData, error)
}
func (c *mockAdminDB) GetCertificateData(sn string) (*db.CertificateData, error) {
return c.MGetCertificateData(sn)
}
func TestGetProvisioners(t *testing.T) {
type gp struct {
a *Authority
@ -104,3 +121,133 @@ func TestGetProvisioners(t *testing.T) {
})
}
}
func TestAuthority_LoadProvisionerByCertificate(t *testing.T) {
_, priv, err := keyutil.GenerateDefaultKeyPair()
assert.FatalError(t, err)
csr := getCSR(t, priv)
sign := func(a *Authority, extraOpts ...provisioner.SignOption) *x509.Certificate {
key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err)
token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key)
assert.FatalError(t, err)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
opts, err := a.Authorize(ctx, token)
assert.FatalError(t, err)
opts = append(opts, extraOpts...)
certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
assert.FatalError(t, err)
return certs[0]
}
getProvisioner := func(a *Authority, name string) provisioner.Interface {
p, ok := a.provisioners.LoadByName(name)
if !ok {
t.Fatalf("provisioner %s does not exists", name)
}
return p
}
removeExtension := provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
for i, ext := range cert.ExtraExtensions {
if ext.Id.Equal(provisioner.StepOIDProvisioner) {
cert.ExtraExtensions = append(cert.ExtraExtensions[:i], cert.ExtraExtensions[i+1:]...)
break
}
}
return nil
})
a0 := testAuthority(t)
a1 := testAuthority(t)
a1.db = &db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
p, err := a1.LoadProvisionerByName("dev")
if err != nil {
t.Fatal(err)
}
return &db.CertificateData{
Provisioner: &db.ProvisionerData{
ID: p.GetID(),
Name: p.GetName(),
Type: p.GetType().String(),
},
}, nil
},
}
a2 := testAuthority(t)
a2.adminDB = &mockAdminDB{
MGetCertificateData: (func(s string) (*db.CertificateData, error) {
p, err := a2.LoadProvisionerByName("dev")
if err != nil {
t.Fatal(err)
}
return &db.CertificateData{
Provisioner: &db.ProvisionerData{
ID: p.GetID(),
Name: p.GetName(),
Type: p.GetType().String(),
},
}, nil
}),
}
a3 := testAuthority(t)
a3.db = &db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
return &db.CertificateData{
Provisioner: &db.ProvisionerData{
ID: "foo", Name: "foo", Type: "foo",
},
}, nil
},
}
a4 := testAuthority(t)
a4.adminDB = &mockAdminDB{
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
return &db.CertificateData{
Provisioner: &db.ProvisionerData{
ID: "foo", Name: "foo", Type: "foo",
},
}, nil
},
}
type args struct {
crt *x509.Certificate
}
tests := []struct {
name string
authority *Authority
args args
want provisioner.Interface
wantErr bool
}{
{"ok from certificate", a0, args{sign(a0)}, getProvisioner(a0, "step-cli"), false},
{"ok from db", a1, args{sign(a1)}, getProvisioner(a1, "dev"), false},
{"ok from admindb", a2, args{sign(a2)}, getProvisioner(a2, "dev"), false},
{"fail from certificate", a0, args{sign(a0, removeExtension)}, nil, true},
{"fail from db", a3, args{sign(a3, removeExtension)}, nil, true},
{"fail from admindb", a4, args{sign(a4, removeExtension)}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.authority.LoadProvisionerByCertificate(tt.args.crt)
if (err != nil) != tt.wantErr {
t.Errorf("Authority.LoadProvisionerByCertificate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.LoadProvisionerByCertificate() = %v, want %v", got, tt.want)
}
})
}
}

@ -347,6 +347,8 @@ func (a *Authority) storeCertificate(prov provisioner.Interface, fullchain []*x5
// Store certificate in local db
switch s := a.db.(type) {
case linkedChainStorer:
return s.StoreCertificateChain(prov, fullchain...)
case certificateChainStorer:
return s.StoreCertificateChain(fullchain...)
default:

@ -23,7 +23,10 @@ import (
"google.golang.org/protobuf/encoding/protojson"
)
var adminURLPrefix = "admin"
const (
adminURLPrefix = "admin"
adminIssuer = "step-admin-client/1.0"
)
// AdminClient implements an HTTP client for the CA server.
type AdminClient struct {
@ -35,7 +38,6 @@ type AdminClient struct {
x5cCertFile string
x5cCertStrs []string
x5cCert *x509.Certificate
x5cIssuer string
x5cSubject string
}
@ -77,24 +79,30 @@ func NewAdminClient(endpoint string, opts ...ClientOption) (*AdminClient, error)
x5cCertFile: o.x5cCertFile,
x5cCertStrs: o.x5cCertStrs,
x5cCert: o.x5cCert,
x5cIssuer: o.x5cIssuer,
x5cSubject: o.x5cSubject,
}, nil
}
func (c *AdminClient) generateAdminToken(urlPath string) (string, error) {
func (c *AdminClient) generateAdminToken(aud *url.URL) (string, error) {
// A random jwt id will be used to identify duplicated tokens
jwtID, err := randutil.Hex(64) // 256 bits
if err != nil {
return "", err
}
// Drop any query string parameter from the token audience
aud = &url.URL{
Scheme: aud.Scheme,
Host: aud.Host,
Path: aud.Path,
}
now := time.Now()
tokOptions := []token.Options{
token.WithJWTID(jwtID),
token.WithKid(c.x5cJWK.KeyID),
token.WithIssuer(c.x5cIssuer),
token.WithAudience(urlPath),
token.WithIssuer(adminIssuer),
token.WithAudience(aud.String()),
token.WithValidity(now, now.Add(token.DefaultValidity)),
token.WithX5CCerts(c.x5cCertStrs),
}
@ -205,7 +213,7 @@ func (c *AdminClient) GetAdminsPaginate(opts ...AdminOption) (*adminAPI.GetAdmin
Path: "/admin/admins",
RawQuery: o.rawQuery(),
})
tok, err := c.generateAdminToken(u.Path)
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, errors.Wrapf(err, "error generating admin token")
}
@ -260,7 +268,7 @@ func (c *AdminClient) CreateAdmin(createAdminRequest *adminAPI.CreateAdminReques
return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/admin/admins"})
tok, err := c.generateAdminToken(u.Path)
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, errors.Wrapf(err, "error generating admin token")
}
@ -292,7 +300,7 @@ retry:
func (c *AdminClient) RemoveAdmin(id string) error {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "admins", id)})
tok, err := c.generateAdminToken(u.Path)
tok, err := c.generateAdminToken(u)
if err != nil {
return errors.Wrapf(err, "error generating admin token")
}
@ -324,7 +332,7 @@ func (c *AdminClient) UpdateAdmin(id string, uar *adminAPI.UpdateAdminRequest) (
return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "admins", id)})
tok, err := c.generateAdminToken(u.Path)
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, errors.Wrapf(err, "error generating admin token")
}
@ -371,7 +379,7 @@ func (c *AdminClient) GetProvisioner(opts ...ProvisionerOption) (*linkedca.Provi
default:
return nil, errors.New("must set either name or id in method options")
}
tok, err := c.generateAdminToken(u.Path)
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, errors.Wrapf(err, "error generating admin token")
}
@ -410,7 +418,7 @@ func (c *AdminClient) GetProvisionersPaginate(opts ...ProvisionerOption) (*admin
Path: "/admin/provisioners",
RawQuery: o.rawQuery(),
})
tok, err := c.generateAdminToken(u.Path)
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, errors.Wrapf(err, "error generating admin token")
}
@ -480,7 +488,7 @@ func (c *AdminClient) RemoveProvisioner(opts ...ProvisionerOption) error {
default:
return errors.New("must set either name or id in method options")
}
tok, err := c.generateAdminToken(u.Path)
tok, err := c.generateAdminToken(u)
if err != nil {
return errors.Wrapf(err, "error generating admin token")
}
@ -512,7 +520,7 @@ func (c *AdminClient) CreateProvisioner(prov *linkedca.Provisioner) (*linkedca.P
return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners")})
tok, err := c.generateAdminToken(u.Path)
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, errors.Wrapf(err, "error generating admin token")
}
@ -548,7 +556,7 @@ func (c *AdminClient) UpdateProvisioner(name string, prov *linkedca.Provisioner)
return errs.Wrap(http.StatusInternalServerError, err, "error marshaling request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", name)})
tok, err := c.generateAdminToken(u.Path)
tok, err := c.generateAdminToken(u)
if err != nil {
return errors.Wrapf(err, "error generating admin token")
}
@ -587,7 +595,7 @@ func (c *AdminClient) GetExternalAccountKeysPaginate(provisionerName, reference
Path: p,
RawQuery: o.rawQuery(),
})
tok, err := c.generateAdminToken(u.Path)
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, errors.Wrapf(err, "error generating admin token")
}
@ -623,7 +631,7 @@ func (c *AdminClient) CreateExternalAccountKey(provisionerName string, eakReques
return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "acme/eab/", provisionerName)})
tok, err := c.generateAdminToken(u.Path)
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, errors.Wrapf(err, "error generating admin token")
}
@ -655,7 +663,7 @@ retry:
func (c *AdminClient) RemoveExternalAccountKey(provisionerName, keyID string) error {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "acme/eab", provisionerName, "/", keyID)})
tok, err := c.generateAdminToken(u.Path)
tok, err := c.generateAdminToken(u)
if err != nil {
return errors.Wrapf(err, "error generating admin token")
}

@ -92,6 +92,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han
for _, s := range nonAuthenticatedPaths {
if strings.HasPrefix(r.URL.Path, s) || strings.HasPrefix(r.URL.Path, "/1.0"+s) {
next.ServeHTTP(w, r)
return
}
}
isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0

@ -10,7 +10,6 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/hex"
"encoding/json"
"encoding/pem"
@ -116,7 +115,6 @@ type clientOptions struct {
x5cCertFile string
x5cCertStrs []string
x5cCert *x509.Certificate
x5cIssuer string
x5cSubject string
}
@ -294,18 +292,6 @@ func WithCertificate(cert tls.Certificate) ClientOption {
}
}
var (
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
)
type stepProvisionerASN1 struct {
Type int
Name []byte
CredentialID []byte
KeyValuePairs []string `asn1:"optional,omitempty"`
}
// WithAdminX5C will set the given file as the X5C certificate for use
// by the client.
func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile string) ClientOption {
@ -332,19 +318,13 @@ func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile strin
}
o.x5cCert = certs[0]
o.x5cSubject = o.x5cCert.Subject.CommonName
for _, e := range o.x5cCert.Extensions {
if e.Id.Equal(stepOIDProvisioner) {
var prov stepProvisionerASN1
if _, err := asn1.Unmarshal(e.Value, &prov); err != nil {
return errors.Wrap(err, "error unmarshaling provisioner OID from certificate")
}
o.x5cIssuer = string(prov.Name)
}
}
if o.x5cIssuer == "" {
return errors.New("provisioner extension not found in certificate")
switch leaf := certs[0]; {
case leaf.Subject.CommonName != "":
o.x5cSubject = leaf.Subject.CommonName
case len(leaf.DNSNames) > 0:
o.x5cSubject = leaf.DNSNames[0]
case len(leaf.EmailAddresses) > 0:
o.x5cSubject = leaf.EmailAddresses[0]
}
return nil

@ -8,6 +8,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
"golang.org/x/crypto/ssh"
@ -15,6 +16,7 @@ import (
var (
certsTable = []byte("x509_certs")
certsDataTable = []byte("x509_certs_data")
revokedCertsTable = []byte("revoked_x509_certs")
revokedSSHCertsTable = []byte("revoked_ssh_certs")
usedOTTTable = []byte("used_ott")
@ -82,7 +84,7 @@ func New(c *Config) (AuthDB, error) {
tables := [][]byte{
revokedCertsTable, certsTable, usedOTTTable,
sshCertsTable, sshHostsTable, sshHostPrincipalsTable, sshUsersTable,
revokedSSHCertsTable,
revokedSSHCertsTable, certsDataTable,
}
for _, b := range tables {
if err := db.CreateTable(b); err != nil {
@ -202,6 +204,19 @@ func (db *DB) GetCertificate(serialNumber string) (*x509.Certificate, error) {
return cert, nil
}
// GetCertificateData returns the data stored for a provisioner
func (db *DB) GetCertificateData(serialNumber string) (*CertificateData, error) {
b, err := db.Get(certsDataTable, []byte(serialNumber))
if err != nil {
return nil, errors.Wrap(err, "database Get error")
}
var data CertificateData
if err := json.Unmarshal(b, &data); err != nil {
return nil, errors.Wrap(err, "error unmarshaling json")
}
return &data, nil
}
// StoreCertificate stores a certificate PEM.
func (db *DB) StoreCertificate(crt *x509.Certificate) error {
if err := db.Set(certsTable, []byte(crt.SerialNumber.String()), crt.Raw); err != nil {
@ -210,6 +225,47 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error {
return nil
}
// CertificateData is the JSON representation of the data stored in
// x509_certs_data table.
type CertificateData struct {
Provisioner *ProvisionerData `json:"provisioner,omitempty"`
}
// ProvisionerData is the JSON representation of the provisioner stored in the
// x509_certs_data table.
type ProvisionerData struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
}
// StoreCertificateChain stores the leaf certificate and the provisioner that
// authorized the certificate.
func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error {
leaf := chain[0]
serialNumber := []byte(leaf.SerialNumber.String())
data := &CertificateData{}
if p != nil {
data.Provisioner = &ProvisionerData{
ID: p.GetID(),
Name: p.GetName(),
Type: p.GetType().String(),
}
}
b, err := json.Marshal(data)
if err != nil {
return errors.Wrap(err, "error marshaling json")
}
// Add certificate and certificate data in one transaction.
tx := new(database.Tx)
tx.Set(certsTable, serialNumber, leaf.Raw)
tx.Set(certsDataTable, serialNumber, b)
if err := db.Update(tx); err != nil {
return errors.Wrap(err, "database Update error")
}
return nil
}
// UseToken returns true if we were able to successfully store the token for
// for the first time, false otherwise.
func (db *DB) UseToken(id, tok string) (bool, error) {
@ -304,6 +360,7 @@ type MockAuthDB struct {
MRevoke func(rci *RevokedCertificateInfo) error
MRevokeSSH func(rci *RevokedCertificateInfo) error
MGetCertificate func(serialNumber string) (*x509.Certificate, error)
MGetCertificateData func(serialNumber string) (*CertificateData, error)
MStoreCertificate func(crt *x509.Certificate) error
MUseToken func(id, tok string) (bool, error)
MIsSSHHost func(principal string) (bool, error)
@ -363,6 +420,17 @@ func (m *MockAuthDB) GetCertificate(serialNumber string) (*x509.Certificate, err
return m.Ret1.(*x509.Certificate), m.Err
}
// GetCertificateData mock.
func (m *MockAuthDB) GetCertificateData(serialNumber string) (*CertificateData, error) {
if m.MGetCertificateData != nil {
return m.MGetCertificateData(serialNumber)
}
if cd, ok := m.Ret1.(*CertificateData); ok {
return cd, m.Err
}
return nil, m.Err
}
// StoreCertificate mock.
func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error {
if m.MStoreCertificate != nil {

@ -1,10 +1,15 @@
package db
import (
"crypto/x509"
"errors"
"math/big"
"reflect"
"testing"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
)
@ -158,3 +163,133 @@ func TestUseToken(t *testing.T) {
})
}
}
func TestDB_StoreCertificateChain(t *testing.T) {
p := &provisioner.JWK{
ID: "some-id",
Name: "admin",
Type: "JWK",
}
chain := []*x509.Certificate{
{Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)},
}
type fields struct {
DB nosql.DB
isUp bool
}
type args struct {
p provisioner.Interface
chain []*x509.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{&MockNoSQLDB{
MUpdate: func(tx *database.Tx) error {
if len(tx.Operations) != 2 {
t.Fatal("unexpected number of operations")
}
assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket)
assert.Equals(t, []byte("1234"), tx.Operations[0].Key)
assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value)
assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket)
assert.Equals(t, []byte("1234"), tx.Operations[1].Key)
assert.Equals(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`), tx.Operations[1].Value)
return nil
},
}, true}, args{p, chain}, false},
{"ok no provisioner", fields{&MockNoSQLDB{
MUpdate: func(tx *database.Tx) error {
if len(tx.Operations) != 2 {
t.Fatal("unexpected number of operations")
}
assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket)
assert.Equals(t, []byte("1234"), tx.Operations[0].Key)
assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value)
assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket)
assert.Equals(t, []byte("1234"), tx.Operations[1].Key)
assert.Equals(t, []byte(`{}`), tx.Operations[1].Value)
return nil
},
}, true}, args{nil, chain}, false},
{"fail store certificate", fields{&MockNoSQLDB{
MUpdate: func(tx *database.Tx) error {
return errors.New("test error")
},
}, true}, args{p, chain}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
d := &DB{
DB: tt.fields.DB,
isUp: tt.fields.isUp,
}
if err := d.StoreCertificateChain(tt.args.p, tt.args.chain...); (err != nil) != tt.wantErr {
t.Errorf("DB.StoreCertificateChain() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestDB_GetCertificateData(t *testing.T) {
type fields struct {
DB nosql.DB
isUp bool
}
type args struct {
serialNumber string
}
tests := []struct {
name string
fields fields
args args
want *CertificateData
wantErr bool
}{
{"ok", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, []byte("x509_certs_data"))
assert.Equals(t, key, []byte("1234"))
return []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`), nil
},
}, true}, args{"1234"}, &CertificateData{
Provisioner: &ProvisionerData{
ID: "some-id", Name: "admin", Type: "JWK",
},
}, false},
{"fail not found", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, database.ErrNotFound
},
}, true}, args{"1234"}, nil, true},
{"fail db", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, errors.New("an error")
},
}, true}, args{"1234"}, nil, true},
{"fail unmarshal", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return []byte(`{"bad-json"}`), nil
},
}, true}, args{"1234"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &DB{
DB: tt.fields.DB,
isUp: tt.fields.isUp,
}
got, err := db.GetCertificateData(tt.args.serialNumber)
if (err != nil) != tt.wantErr {
t.Errorf("DB.GetCertificateData() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("DB.GetCertificateData() = %v, want %v", got, tt.want)
}
})
}
}

@ -40,7 +40,7 @@ require (
go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352
go.step.sm/cli-utils v0.7.0
go.step.sm/crypto v0.16.1
go.step.sm/linkedca v0.12.0
go.step.sm/linkedca v0.15.0
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
google.golang.org/api v0.70.0

@ -804,8 +804,8 @@ go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/
go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0=
go.step.sm/crypto v0.16.1 h1:4mnZk21cSxyMGxsEpJwZKKvJvDu1PN09UVrWWFNUBdk=
go.step.sm/crypto v0.16.1/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g=
go.step.sm/linkedca v0.12.0 h1:FA18uJO5P6W2pklcezMs+w+N3dVbpKEE1LP9HLsJgg4=
go.step.sm/linkedca v0.12.0/go.mod h1:W59ucS4vFpuR0g4PtkGbbtXAwxbDEnNCg+ovkej1ANM=
go.step.sm/linkedca v0.15.0 h1:lEkGRDY+u7FudGKt8yEo7nBy5OzceO9s3rl+/sZVL5M=
go.step.sm/linkedca v0.15.0/go.mod h1:W59ucS4vFpuR0g4PtkGbbtXAwxbDEnNCg+ovkej1ANM=
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=

Loading…
Cancel
Save