diff --git a/acme/account_test.go b/acme/account_test.go index 25600028..91327080 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -27,7 +27,7 @@ var ( } ) -func newProv() provisioner.Interface { +func newProv() Provisioner { // Initialize provisioners p := &provisioner.ACME{ Type: "ACME", diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 9674b035..f8bac96c 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -244,7 +244,11 @@ func TestHandlerGetNonce(t *testing.T) { } func TestHandlerGetDirectory(t *testing.T) { - auth, err := acme.NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil) + auth, err := acme.New(nil, acme.AuthorityOptions{ + DB: new(db.MockNoSQLDB), + DNS: "ca.smallstep.com", + Prefix: "acme", + }) assert.FatalError(t, err) prov := newProv() diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 93a85a7f..f7d7dcf4 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -278,11 +278,12 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { api.WriteError(w, err) return } - if p.GetType() != provisioner.TypeACME { + acmeProv, ok := p.(*provisioner.ACME) + if !ok { api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME"))) return } - ctx = context.WithValue(ctx, acme.ProvisionerContextKey, p) + ctx = context.WithValue(ctx, acme.ProvisionerContextKey, acme.Provisioner(acmeProv)) next(w, r.WithContext(ctx)) } } diff --git a/acme/authority.go b/acme/authority.go index eaefca55..66bc1e00 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -47,11 +47,28 @@ type Interface interface { // Authority is the layer that handles all ACME interactions. type Authority struct { + backdate provisioner.Duration db nosql.DB dir *directory signAuth SignAuthority } +// AuthorityOptions required to create a new ACME Authority. +type AuthorityOptions struct { + Backdate provisioner.Duration + // DB is the database used by nosql. + DB nosql.DB + // DNS the host used to generate accurate ACME links. By default the authority + // will use the Host from the request, so this value will only be used if + // request.Host is empty. + DNS string + // Prefix is a URL path prefix under which the ACME api is served. This + // prefix is required to generate accurate ACME links. + // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account -- + // "acme" is the prefix from which the ACME api is accessed. + Prefix string +} + var ( accountTable = []byte("acme_accounts") accountByKeyIDTable = []byte("acme_keyID_accountID_index") @@ -64,22 +81,34 @@ var ( ) // NewAuthority returns a new Authority that implements the ACME interface. +// +// Deprecated: NewAuthority exists for hitorical compatibility and should not +// be used. Use acme.New() instead. func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) (*Authority, error) { - if _, ok := db.(*database.SimpleDB); !ok { + return New(signAuth, AuthorityOptions{ + DB: db, + DNS: dns, + Prefix: prefix, + }) +} + +// New returns a new Autohrity that implements the ACME interface. +func New(signAuth SignAuthority, ops AuthorityOptions) (*Authority, error) { + if _, ok := ops.DB.(*database.SimpleDB); !ok { // If it's not a SimpleDB then go ahead and bootstrap the DB with the // necessary ACME tables. SimpleDB should ONLY be used for testing. tables := [][]byte{accountTable, accountByKeyIDTable, authzTable, challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable} for _, b := range tables { - if err := db.CreateTable(b); err != nil { + if err := ops.DB.CreateTable(b); err != nil { return nil, errors.Wrapf(err, "error creating table %s", string(b)) } } } return &Authority{ - db: db, dir: newDirectory(dns, prefix), signAuth: signAuth, + backdate: ops.Backdate, db: ops.DB, dir: newDirectory(ops.DNS, ops.Prefix), signAuth: signAuth, }, nil } @@ -225,6 +254,12 @@ func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string // NewOrder generates, stores, and returns a new ACME order. func (a *Authority) NewOrder(ctx context.Context, ops OrderOptions) (*Order, error) { + prov, err := ProvisionerFromContext(ctx) + if err != nil { + return nil, err + } + ops.backdate = a.backdate.Duration + ops.defaultDuration = prov.DefaultTLSCertDuration() order, err := newOrder(a.db, ops) if err != nil { return nil, Wrap(err, "error creating order") diff --git a/acme/authority_test.go b/acme/authority_test.go index aec022a3..e11b91db 100644 --- a/acme/authority_test.go +++ b/acme/authority_test.go @@ -925,10 +925,21 @@ func TestAuthorityNewOrder(t *testing.T) { type test struct { auth *Authority ops OrderOptions + ctx context.Context err *Error o **Order } tests := map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + auth, err := NewAuthority(&db.MockNoSQLDB{}, "ca.smallstep.com", "acme", nil) + assert.FatalError(t, err) + return test{ + auth: auth, + ops: defaultOrderOps(), + ctx: context.Background(), + err: ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, "fail/newOrder-error": func(t *testing.T) test { auth, err := NewAuthority(&db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { @@ -939,6 +950,7 @@ func TestAuthorityNewOrder(t *testing.T) { return test{ auth: auth, ops: defaultOrderOps(), + ctx: ctx, err: ServerInternalErr(errors.New("error creating order: error creating http challenge: error saving acme challenge: force")), } }, @@ -993,6 +1005,7 @@ func TestAuthorityNewOrder(t *testing.T) { return test{ auth: auth, ops: defaultOrderOps(), + ctx: ctx, o: acmeO, } }, @@ -1000,7 +1013,7 @@ func TestAuthorityNewOrder(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if acmeO, err := tc.auth.NewOrder(ctx, tc.ops); err != nil { + if acmeO, err := tc.auth.NewOrder(tc.ctx, tc.ops); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -1160,10 +1173,21 @@ func TestAuthorityFinalizeOrder(t *testing.T) { type test struct { auth *Authority id, accID string + ctx context.Context err *Error o *order } tests := map[string]func(t *testing.T) test{ + "fail/no-provisioner": func(t *testing.T) test { + auth, err := NewAuthority(&db.MockNoSQLDB{}, "ca.smallstep.com", "acme", nil) + assert.FatalError(t, err) + return test{ + auth: auth, + id: "foo", + ctx: context.Background(), + err: ServerInternalErr(errors.New("provisioner expected in request context")), + } + }, "fail/getOrder-error": func(t *testing.T) test { id := "foo" auth, err := NewAuthority(&db.MockNoSQLDB{ @@ -1177,6 +1201,7 @@ func TestAuthorityFinalizeOrder(t *testing.T) { return test{ auth: auth, id: id, + ctx: ctx, err: ServerInternalErr(errors.New("error loading order foo: force")), } }, @@ -1197,6 +1222,7 @@ func TestAuthorityFinalizeOrder(t *testing.T) { auth: auth, id: o.ID, accID: "foo", + ctx: ctx, err: UnauthorizedErr(errors.New("account does not own order")), } }, @@ -1223,6 +1249,7 @@ func TestAuthorityFinalizeOrder(t *testing.T) { auth: auth, id: o.ID, accID: o.AccountID, + ctx: ctx, err: ServerInternalErr(errors.New("error finalizing order: error storing order: force")), } }, @@ -1245,6 +1272,7 @@ func TestAuthorityFinalizeOrder(t *testing.T) { auth: auth, id: o.ID, accID: o.AccountID, + ctx: ctx, o: o, } }, @@ -1252,7 +1280,7 @@ func TestAuthorityFinalizeOrder(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if acmeO, err := tc.auth.FinalizeOrder(ctx, tc.accID, tc.id, nil); err != nil { + if acmeO, err := tc.auth.FinalizeOrder(tc.ctx, tc.accID, tc.id, nil); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) diff --git a/acme/common.go b/acme/common.go index 8b878016..d8b2b7e4 100644 --- a/acme/common.go +++ b/acme/common.go @@ -12,6 +12,47 @@ import ( "github.com/smallstep/cli/jose" ) +// Provisioner is an interface that implements a subset of the provisioner.Interface -- +// only those methods required by the ACME api/authority. +type Provisioner interface { + AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) + GetName() string + DefaultTLSCertDuration() time.Duration +} + +// MockProvisioner for testing +type MockProvisioner struct { + Mret1 interface{} + Merr error + MgetName func() string + MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) + MdefaultTLSCertDuration func() time.Duration +} + +// GetName mock +func (m *MockProvisioner) GetName() string { + if m.MgetName != nil { + return m.MgetName() + } + return m.Mret1.(string) +} + +// AuthorizeSign mock +func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) { + if m.MauthorizeSign != nil { + return m.MauthorizeSign(ctx, ott) + } + return m.Mret1.([]provisioner.SignOption), m.Merr +} + +// DefaultTLSCertDuration mock +func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration { + if m.MdefaultTLSCertDuration != nil { + return m.MdefaultTLSCertDuration() + } + return m.Mret1.(time.Duration) +} + // ContextKey is the key type for storing and searching for ACME request // essentials in the context of a request. type ContextKey string @@ -70,12 +111,16 @@ func JwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { // ProvisionerFromContext searches the context for a provisioner. Returns the // provisioner or an error. -func ProvisionerFromContext(ctx context.Context) (provisioner.Interface, error) { - val, ok := ctx.Value(ProvisionerContextKey).(provisioner.Interface) - if !ok || val == nil { +func ProvisionerFromContext(ctx context.Context) (Provisioner, error) { + val := ctx.Value(ProvisionerContextKey) + if val == nil { return nil, ServerInternalErr(errors.Errorf("provisioner expected in request context")) } - return val, nil + pval, ok := val.(Provisioner) + if !ok || pval == nil { + return nil, ServerInternalErr(errors.Errorf("provisioner in context is not an ACME provisioner")) + } + return pval, nil } // SignAuthority is the interface implemented by a CA authority. diff --git a/acme/order.go b/acme/order.go index 3f02bc51..839af337 100644 --- a/acme/order.go +++ b/acme/order.go @@ -45,10 +45,12 @@ func (o *Order) GetID() string { // OrderOptions options with which to create a new Order. type OrderOptions struct { - AccountID string `json:"accID"` - Identifiers []Identifier `json:"identifiers"` - NotBefore time.Time `json:"notBefore"` - NotAfter time.Time `json:"notAfter"` + AccountID string `json:"accID"` + Identifiers []Identifier `json:"identifiers"` + NotBefore time.Time `json:"notBefore"` + NotAfter time.Time `json:"notAfter"` + backdate time.Duration + defaultDuration time.Duration } type order struct { @@ -82,6 +84,17 @@ func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { } now := clock.Now() + var backdate time.Duration + nbf := ops.NotBefore + if nbf.IsZero() { + nbf = now + backdate = -1 * ops.backdate + } + naf := ops.NotAfter + if naf.IsZero() { + naf = nbf.Add(ops.defaultDuration) + } + o := &order{ ID: id, AccountID: ops.AccountID, @@ -89,8 +102,8 @@ func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { Status: StatusPending, Expires: now.Add(defaultOrderExpiry), Identifiers: ops.Identifiers, - NotBefore: ops.NotBefore, - NotAfter: ops.NotAfter, + NotBefore: nbf.Add(backdate), + NotAfter: naf, Authorizations: authzs, } if err := o.save(db, nil); err != nil { @@ -236,7 +249,7 @@ func (o *order) updateStatus(db nosql.DB) (*order, error) { // finalize signs a certificate if the necessary conditions for Order completion // have been met. -func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p provisioner.Interface) (*order, error) { +func (o *order) finalize(db nosql.DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) (*order, error) { var err error if o, err = o.updateStatus(db); err != nil { return nil, err diff --git a/acme/order_test.go b/acme/order_test.go index 2ac68657..86a4eb32 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -309,7 +309,7 @@ func TestOrderSave(t *testing.T) { } } -func TestNewOrder(t *testing.T) { +func Test_newOrder(t *testing.T) { type test struct { ops OrderOptions db nosql.DB @@ -436,6 +436,49 @@ func TestNewOrder(t *testing.T) { authzs: authzs, } }, + "ok/validity-bounds-not-set": func(t *testing.T) test { + count := 0 + oids := []string{"1", "2", "3"} + oidsB, err := json.Marshal(oids) + assert.FatalError(t, err) + authzs := &([]string{}) + var ( + _oid = "" + oid = &_oid + ) + ops := defaultOrderOps() + ops.backdate = time.Minute + ops.defaultDuration = 12 * time.Hour + ops.NotBefore = time.Time{} + ops.NotAfter = time.Time{} + return test{ + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + if count >= 9 { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(ops.AccountID)) + assert.Equals(t, old, oidsB) + newB, err := json.Marshal(append(oids, *oid)) + assert.FatalError(t, err) + assert.Equals(t, newval, newB) + } else if count == 8 { + *oid = string(key) + } else if count == 7 { + *authzs = append(*authzs, string(key)) + } else if count == 3 { + *authzs = []string{string(key)} + } + count++ + return nil, true, nil + }, + MGet: func(bucket, key []byte) ([]byte, error) { + return oidsB, nil + }, + }, + authzs: authzs, + } + }, } for name, run := range tests { tc := run(t) @@ -465,8 +508,21 @@ func TestNewOrder(t *testing.T) { assert.True(t, o.Expires.Before(expiry.Add(time.Minute))) assert.True(t, o.Expires.After(expiry.Add(-1*time.Minute))) - assert.Equals(t, o.NotBefore, tc.ops.NotBefore) - assert.Equals(t, o.NotAfter, tc.ops.NotAfter) + nbf := tc.ops.NotBefore + now := time.Now().UTC() + if !tc.ops.NotBefore.IsZero() { + assert.Equals(t, o.NotBefore, tc.ops.NotBefore) + } else { + nbf = o.NotBefore.Add(tc.ops.backdate) + assert.True(t, o.NotBefore.Before(now.Add(-tc.ops.backdate+time.Second))) + assert.True(t, o.NotBefore.Add(tc.ops.backdate+2*time.Second).After(now)) + } + if !tc.ops.NotAfter.IsZero() { + assert.Equals(t, o.NotAfter, tc.ops.NotAfter) + } else { + naf := nbf.Add(tc.ops.defaultDuration) + assert.Equals(t, o.NotAfter, naf) + } } } }) @@ -861,7 +917,7 @@ func TestOrderFinalize(t *testing.T) { db nosql.DB csr *x509.CertificateRequest sa SignAuthority - prov provisioner.Interface + prov Provisioner } tests := map[string]func(t *testing.T) test{ "fail/already-invalid": func(t *testing.T) test { @@ -1008,7 +1064,7 @@ func TestOrderFinalize(t *testing.T) { o: o, csr: csr, err: ServerInternalErr(errors.New("error retrieving authorization options from ACME provisioner: force")), - prov: &provisioner.MockProvisioner{ + prov: &MockProvisioner{ MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) { return nil, errors.New("force") }, diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 95115e6d..88189d01 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -3,6 +3,7 @@ package provisioner import ( "context" "crypto/x509" + "time" "github.com/pkg/errors" "github.com/smallstep/certificates/errs" @@ -44,6 +45,12 @@ func (p *ACME) GetEncryptedKey() (string, string, bool) { return "", "", false } +// DefaultTLSCertDuration returns the default TLS cert duration enforced by +// the provisioner. +func (p *ACME) DefaultTLSCertDuration() time.Duration { + return p.claimer.DefaultTLSCertDuration() +} + // Init initializes and validates the fields of a JWK type. func (p *ACME) Init(config Config) (err error) { switch { diff --git a/ca/ca.go b/ca/ca.go index 96bebba4..3c57b759 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -124,7 +124,12 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { } prefix := "acme" - acmeAuth, err := acme.NewAuthority(auth.GetDatabase().(nosql.DB), dns, prefix, auth) + acmeAuth, err := acme.New(auth, acme.AuthorityOptions{ + Backdate: *config.AuthorityConfig.Backdate, + DB: auth.GetDatabase().(nosql.DB), + DNS: dns, + Prefix: prefix, + }) if err != nil { return nil, errors.Wrap(err, "error creating ACME authority") }