diff --git a/acme/api/account.go b/acme/api/account.go index f6e18f90..710747ca 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -67,8 +67,11 @@ func (u *UpdateAccountRequest) Validate() error { } // NewAccount is the handler resource for creating new ACME accounts. -func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { +func NewAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) @@ -114,7 +117,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { return } - eak, err := h.validateExternalAccountBinding(ctx, &nar) + eak, err := validateExternalAccountBinding(ctx, &nar) if err != nil { render.Error(w, err) return @@ -125,7 +128,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { Contact: nar.Contact, Status: acme.StatusValid, } - if err := h.db.CreateAccount(ctx, acc); err != nil { + if err := db.CreateAccount(ctx, acc); err != nil { render.Error(w, acme.WrapErrorISE(err, "error creating account")) return } @@ -135,7 +138,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { + if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key")) return } @@ -146,15 +149,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { httpStatus = http.StatusOK } - h.linker.LinkAccount(ctx, acc) + linker.LinkAccount(ctx, acc) - w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) + w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID)) render.JSONStatus(w, acc, httpStatus) } // GetOrUpdateAccount is the api for updating an ACME account. -func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { +func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -186,16 +192,16 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { acc.Contact = uar.Contact } - if err := h.db.UpdateAccount(ctx, acc); err != nil { + if err := db.UpdateAccount(ctx, acc); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating account")) return } } } - h.linker.LinkAccount(ctx, acc) + linker.LinkAccount(ctx, acc) - w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID)) render.JSON(w, acc) } @@ -209,8 +215,11 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) { } // GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account. -func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { +func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -221,13 +230,14 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) return } - orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID) + + orders, err := db.GetOrdersByAccountID(ctx, acc.ID) if err != nil { render.Error(w, err) return } - h.linker.LinkOrdersByAccountID(ctx, orders) + linker.LinkOrdersByAccountID(ctx, orders) render.JSON(w, orders) logOrdersByAccount(w, orders) diff --git a/acme/api/account_test.go b/acme/api/account_test.go index a0161cb4..d81553d2 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -31,6 +31,22 @@ var ( } ) +type fakeProvisioner struct{} + +func (*fakeProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error { + return nil +} + +func (*fakeProvisioner) AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) { + return nil, nil +} + +func (*fakeProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { return nil } +func (*fakeProvisioner) GetID() string { return "" } +func (*fakeProvisioner) GetName() string { return "" } +func (*fakeProvisioner) DefaultTLSCertDuration() time.Duration { return 0 } +func (*fakeProvisioner) GetOptions() *provisioner.Options { return nil } + func newProv() acme.Provisioner { // Initialize provisioners p := &provisioner.ACME{ @@ -320,10 +336,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: accID} - ctx := context.WithValue(context.Background(), accContextKey, acc) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = acme.NewProvisionerContext(ctx, prov) + ctx = context.WithValue(ctx, accContextKey, acc) return test{ db: &acme.MockDB{ MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) { @@ -339,11 +354,11 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetOrdersByAccountID(w, req) + GetOrdersByAccountID(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -387,6 +402,7 @@ func TestHandler_NewAccount(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-payload": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -395,6 +411,7 @@ func TestHandler_NewAccount(t *testing.T) { "fail/nil-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -403,6 +420,7 @@ func TestHandler_NewAccount(t *testing.T) { "fail/unmarshal-payload-error": func(t *testing.T) test { ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to "+ @@ -417,6 +435,7 @@ func TestHandler_NewAccount(t *testing.T) { assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), @@ -429,8 +448,9 @@ func TestHandler_NewAccount(t *testing.T) { b, err := json.Marshal(nar) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -442,9 +462,10 @@ func TestHandler_NewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jwk expected in request context"), @@ -456,10 +477,11 @@ func TestHandler_NewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jwk expected in request context"), @@ -478,9 +500,9 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"), @@ -495,7 +517,7 @@ func TestHandler_NewAccount(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwkContextKey, jwk) return test{ db: &acme.MockDB{ @@ -525,18 +547,11 @@ func TestHandler_NewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - scepProvisioner := &provisioner.SCEP{ - Type: "SCEP", - Name: "test@scep-provisioner.com", - } - if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { - assert.FatalError(t, err) - } ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner) + ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"), @@ -575,8 +590,7 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) eak := &acme.ExternalAccountKey{ ID: "eakID", @@ -623,8 +637,7 @@ func TestHandler_NewAccount(t *testing.T) { assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { @@ -659,11 +672,11 @@ func TestHandler_NewAccount(t *testing.T) { Status: acme.StatusValid, Contact: []string{"foo", "bar"}, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + db: &acme.MockDB{}, ctx: ctx, acc: acc, statusCode: 200, @@ -688,8 +701,7 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = false ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{ MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { @@ -743,8 +755,7 @@ func TestHandler_NewAccount(t *testing.T) { prov.RequireEAB = true ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -783,11 +794,11 @@ func TestHandler_NewAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.NewAccount(w, req) + NewAccount(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -838,6 +849,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -846,6 +858,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { "fail/nil-account": func(t *testing.T) test { ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -854,6 +867,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { "fail/no-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), accContextKey, &acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -863,6 +877,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -872,6 +887,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"), @@ -886,6 +902,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), @@ -918,10 +935,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { @@ -938,11 +954,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { uar := &UpdateAccountRequest{} b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 200, } @@ -953,10 +969,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { @@ -970,11 +985,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { } }, "ok/post-as-get": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 200, } @@ -983,11 +998,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetOrUpdateAccount(w, req) + GetOrUpdateAccount(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) diff --git a/acme/api/eab.go b/acme/api/eab.go index 84be6453..cf4f1993 100644 --- a/acme/api/eab.go +++ b/acme/api/eab.go @@ -17,7 +17,7 @@ type ExternalAccountBinding struct { } // validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account. -func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) { +func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) { acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context") @@ -48,7 +48,8 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc return nil, acmeErr } - externalAccountKey, err := h.db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) + db := acme.MustDatabaseFromContext(ctx) + externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) if err != nil { if _, ok := err.(*acme.Error); ok { return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key") @@ -111,7 +112,6 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool { // o The "nonce" field MUST NOT be present // o The "url" field MUST be set to the same value as the outer JWS func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) { - if jws == nil { return "", acme.NewErrorISE("no JWS provided") } diff --git a/acme/api/eab_test.go b/acme/api/eab_test.go index c2725588..d2e596f9 100644 --- a/acme/api/eab_test.go +++ b/acme/api/eab_test.go @@ -14,7 +14,6 @@ import ( "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/authority/provisioner" ) func Test_keysAreEqual(t *testing.T) { @@ -100,8 +99,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { assert.FatalError(t, err) prov := newACMEProv(t) ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, @@ -145,8 +143,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) createdAt := time.Now() return test{ @@ -191,17 +188,10 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - scepProvisioner := &provisioner.SCEP{ - Type: "SCEP", - Name: "test@scep-provisioner.com", - } - if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { - assert.FatalError(t, err) - } + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner) + ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{}) return test{ ctx: ctx, err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"), @@ -220,8 +210,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ db: &acme.MockDB{}, ctx: ctx, @@ -266,8 +255,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{}, @@ -312,8 +300,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -360,8 +347,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -410,8 +396,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -460,8 +445,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -510,8 +494,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) createdAt := time.Now() return test{ @@ -568,8 +551,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -616,8 +598,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) createdAt := time.Now() boundAt := time.Now().Add(1 * time.Second) @@ -676,8 +657,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -734,8 +714,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, jwk) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -789,8 +768,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { assert.FatalError(t, err) prov := newACMEProv(t) prov.RequireEAB = true - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -845,8 +823,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { prov := newACMEProv(t) prov.RequireEAB = true ctx := context.WithValue(context.Background(), jwkContextKey, nil) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ db: &acme.MockDB{ @@ -873,10 +850,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - db: tc.db, - } - got, err := h.validateExternalAccountBinding(tc.ctx, tc.nar) + ctx := acme.NewDatabaseContext(tc.ctx, tc.db) + got, err := validateExternalAccountBinding(ctx, tc.nar) wantErr := tc.err != nil gotErr := err != nil if wantErr != gotErr { diff --git a/acme/api/handler.go b/acme/api/handler.go index 10eb22cb..2e3931b1 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -2,12 +2,10 @@ package api import ( "context" - "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" "fmt" - "net" "net/http" "time" @@ -16,6 +14,7 @@ import ( "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" ) @@ -39,111 +38,152 @@ type payloadInfo struct { isEmptyJSON bool } -// Handler is the ACME API request handler. -type Handler struct { - db acme.DB - backdate provisioner.Duration - ca acme.CertificateAuthority - linker Linker - validateChallengeOptions *acme.ValidateChallengeOptions - prerequisitesChecker func(ctx context.Context) (bool, error) -} - // HandlerOptions required to create a new ACME API request handler. type HandlerOptions struct { - Backdate provisioner.Duration - // DB storage backend that impements the acme.DB interface. + // DB storage backend that implements the acme.DB interface. + // + // Deprecated: use acme.NewContex(context.Context, acme.DB) DB acme.DB + + // CA is the certificate authority interface. + // + // Deprecated: use authority.NewContext(context.Context, *authority.Authority) + CA acme.CertificateAuthority + + // Backdate is the duration that the CA will subtract from the current time + // to set the NotBefore in the certificate. + Backdate provisioner.Duration + // DNS the host used to generate accurate ACME links. By default the authority // will use the Host from the request, so this value will only be used if // request.Host is empty. DNS string + // Prefix is a URL path prefix under which the ACME api is served. This // prefix is required to generate accurate ACME links. // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account -- // "acme" is the prefix from which the ACME api is accessed. Prefix string - CA acme.CertificateAuthority + // PrerequisitesChecker checks if all prerequisites for serving ACME are // met by the CA configuration. PrerequisitesChecker func(ctx context.Context) (bool, error) } +var mustAuthority = func(ctx context.Context) acme.CertificateAuthority { + return authority.MustFromContext(ctx) +} + +// handler is the ACME API request handler. +type handler struct { + opts *HandlerOptions +} + +// Route traffic and implement the Router interface. For backward compatibility +// this route adds will add a new middleware that will set the ACME components +// on the context. +// +// Note: this method is deprecated in step-ca, other applications can still use +// this to support ACME, but the recommendation is to use use +// api.Route(api.Router) and acme.NewContext() instead. +func (h *handler) Route(r api.Router) { + client := acme.NewClient() + linker := acme.NewLinker(h.opts.DNS, h.opts.Prefix) + route(r, func(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil { + ctx = authority.NewContext(ctx, ca) + } + ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker) + next(w, r.WithContext(ctx)) + } + }) +} + // NewHandler returns a new ACME API handler. -func NewHandler(ops HandlerOptions) api.RouterHandler { - transport := &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - } - client := http.Client{ - Timeout: 30 * time.Second, - Transport: transport, - } - dialer := &net.Dialer{ - Timeout: 30 * time.Second, - } - prerequisitesChecker := func(ctx context.Context) (bool, error) { - // by default all prerequisites are met - return true, nil - } - if ops.PrerequisitesChecker != nil { - prerequisitesChecker = ops.PrerequisitesChecker - } - return &Handler{ - ca: ops.CA, - db: ops.DB, - backdate: ops.Backdate, - linker: NewLinker(ops.DNS, ops.Prefix), - validateChallengeOptions: &acme.ValidateChallengeOptions{ - HTTPGet: client.Get, - LookupTxt: net.LookupTXT, - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.DialWithDialer(dialer, network, addr, config) - }, - }, - prerequisitesChecker: prerequisitesChecker, +// +// Note: this method is deprecated in step-ca, other applications can still use +// this to support ACME, but the recommendation is to use use +// api.Route(api.Router) and acme.NewContext() instead. +func NewHandler(opts HandlerOptions) api.RouterHandler { + return &handler{ + opts: &opts, } } -// Route traffic and implement the Router interface. -func (h *Handler) Route(r api.Router) { - getPath := h.linker.GetUnescapedPathSuffix - // Standard ACME API - r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) - r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) - r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) - r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) +// Route traffic and implement the Router interface. This method requires that +// all the acme components, authority, db, client, linker, and prerequisite +// checker to be present in the context. +func Route(r api.Router) { + route(r, nil) +} +func route(r api.Router, middleware func(next nextHTTP) nextHTTP) { + commonMiddleware := func(next nextHTTP) nextHTTP { + handler := func(w http.ResponseWriter, r *http.Request) { + // Linker middleware gets the provisioner and current url from the + // request and sets them in the context. + linker := acme.MustLinkerFromContext(r.Context()) + linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r) + } + if middleware != nil { + handler = middleware(handler) + } + return handler + } validatingMiddleware := func(next nextHTTP) nextHTTP { - return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))) + return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next)))))) } extractPayloadByJWK := func(next nextHTTP) nextHTTP { - return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next))) + return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next))) } extractPayloadByKid := func(next nextHTTP) nextHTTP { - return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next))) + return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next))) } extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP { - return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next))) + return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next))) } - r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount)) - r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount)) - r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented)) - r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder)) - r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) - r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID))) - r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) - r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) - r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) - r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) - r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert)) + getPath := acme.GetUnescapedPathSuffix + + // Standard ACME API + r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"), + commonMiddleware(addNonce(addDirLink(GetNonce)))) + r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"), + commonMiddleware(addNonce(addDirLink(GetNonce)))) + r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"), + commonMiddleware(GetDirectory)) + r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"), + commonMiddleware(GetDirectory)) + + r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"), + extractPayloadByJWK(NewAccount)) + r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"), + extractPayloadByKid(GetOrUpdateAccount)) + r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"), + extractPayloadByKid(NotImplemented)) + r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"), + extractPayloadByKid(NewOrder)) + r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"), + extractPayloadByKid(isPostAsGet(GetOrder))) + r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"), + extractPayloadByKid(isPostAsGet(GetOrdersByAccountID))) + r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"), + extractPayloadByKid(FinalizeOrder)) + r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"), + extractPayloadByKid(isPostAsGet(GetAuthorization))) + r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), + extractPayloadByKid(GetChallenge)) + r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"), + extractPayloadByKid(isPostAsGet(GetCertificate))) + r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"), + extractPayloadByKidOrJWK(RevokeCert)) } // GetNonce just sets the right header since a Nonce is added to each response // by middleware by default. -func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) { +func GetNonce(w http.ResponseWriter, r *http.Request) { if r.Method == "HEAD" { w.WriteHeader(http.StatusOK) } else { @@ -179,7 +219,7 @@ func (d *Directory) ToLog() (interface{}, error) { // GetDirectory is the ACME resource for returning a directory configuration // for client configuration. -func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { +func GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { @@ -187,12 +227,13 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { return } + linker := acme.MustLinkerFromContext(ctx) render.JSON(w, &Directory{ - NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), - NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), - NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), - RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType), - KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType), + NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType), + NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType), + NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType), + RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType), + KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType), Meta: Meta{ ExternalAccountRequired: acmeProv.RequireEAB, }, @@ -201,19 +242,22 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { // NotImplemented returns a 501 and is generally a placeholder for functionality which // MAY be added at some point in the future but is not in any way a guarantee of such. -func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { +func NotImplemented(w http.ResponseWriter, r *http.Request) { render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } // GetAuthorization ACME api for retrieving an Authz. -func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { +func GetAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) + az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization")) return @@ -223,20 +267,23 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { "account '%s' does not own authorization '%s'", acc.ID, az.ID)) return } - if err = az.UpdateStatus(ctx, h.db); err != nil { + if err = az.UpdateStatus(ctx, db); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating authorization status")) return } - h.linker.LinkAuthorization(ctx, az) + linker.LinkAuthorization(ctx, az) - w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID)) render.JSON(w, az) } // GetChallenge ACME api for retrieving a Challenge. -func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { +func GetChallenge(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -257,7 +304,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { // we'll just ignore the body. azID := chi.URLParam(r, "authzID") - ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) + ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge")) return @@ -273,29 +320,31 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { + if err = ch.Validate(ctx, db, jwk); err != nil { render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) return } - h.linker.LinkChallenge(ctx, ch, azID) + linker.LinkChallenge(ctx, ch, azID) - w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up")) - w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) + w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up")) + w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID)) render.JSON(w, ch) } // GetCertificate ACME api for retrieving a Certificate. -func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { +func GetCertificate(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - certID := chi.URLParam(r, "certID") - cert, err := h.db.GetCertificate(ctx, certID) + certID := chi.URLParam(r, "certID") + cert, err := db.GetCertificate(ctx, certID) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate")) return diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 67f7df30..822409df 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -3,6 +3,7 @@ package api import ( "bytes" "context" + "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" @@ -19,11 +20,33 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" ) +type mockClient struct { + get func(url string) (*http.Response, error) + lookupTxt func(name string) ([]string, error) + tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error) +} + +func (m *mockClient) Get(u string) (*http.Response, error) { return m.get(u) } +func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) } +func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { + return m.tlsDial(network, addr, config) +} + +func mockMustAuthority(t *testing.T, a acme.CertificateAuthority) { + t.Helper() + fn := mustAuthority + t.Cleanup(func() { + mustAuthority = fn + }) + mustAuthority = func(ctx context.Context) acme.CertificateAuthority { + return a + } +} + func TestHandler_GetNonce(t *testing.T) { tests := []struct { name string @@ -38,10 +61,10 @@ func TestHandler_GetNonce(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := &Handler{} + // h := &Handler{} w := httptest.NewRecorder() req.Method = tt.name - h.GetNonce(w, req) + GetNonce(w, req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -52,7 +75,8 @@ func TestHandler_GetNonce(t *testing.T) { } func TestHandler_GetDirectory(t *testing.T) { - linker := NewLinker("ca.smallstep.com", "acme") + linker := acme.NewLinker("ca.smallstep.com", "acme") + _ = linker type test struct { ctx context.Context statusCode int @@ -61,23 +85,14 @@ func TestHandler_GetDirectory(t *testing.T) { } var tests = map[string]func(t *testing.T) test{ "fail/no-provisioner": func(t *testing.T) test { - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - ctx: ctx, + ctx: context.Background(), statusCode: 500, - err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), + err: acme.NewErrorISE("provisioner is not in context"), } }, "fail/different-provisioner": func(t *testing.T) test { - prov := &provisioner.SCEP{ - Type: "SCEP", - Name: "test@scep-provisioner.com", - } - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), &fakeProvisioner{}) return test{ ctx: ctx, statusCode: 500, @@ -88,8 +103,7 @@ func TestHandler_GetDirectory(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), @@ -108,8 +122,7 @@ func TestHandler_GetDirectory(t *testing.T) { prov.RequireEAB = true provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), @@ -130,11 +143,11 @@ func TestHandler_GetDirectory(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: linker} + ctx := acme.NewLinkerContext(tc.ctx, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetDirectory(w, req) + GetDirectory(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -219,7 +232,7 @@ func TestHandler_GetAuthorization(t *testing.T) { } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ db: &acme.MockDB{}, @@ -285,10 +298,9 @@ func TestHandler_GetAuthorization(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { @@ -304,11 +316,11 @@ func TestHandler_GetAuthorization(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetAuthorization(w, req) + GetAuthorization(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -447,11 +459,11 @@ func TestHandler_GetCertificate(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + ctx := acme.NewDatabaseContext(tc.ctx, tc.db) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetCertificate(w, req) + GetCertificate(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -491,7 +503,7 @@ func TestHandler_GetChallenge(t *testing.T) { type test struct { db acme.DB - vco *acme.ValidateChallengeOptions + vc acme.Client ctx context.Context statusCode int ch *acme.Challenge @@ -500,6 +512,7 @@ func TestHandler_GetChallenge(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -507,6 +520,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/nil-account": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), accContextKey, nil), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -516,6 +530,7 @@ func TestHandler_GetChallenge(t *testing.T) { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -523,10 +538,11 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload expected in request context"), @@ -534,7 +550,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/db.GetChallenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -553,7 +569,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -572,7 +588,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/no-jwk": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -591,7 +607,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/nil-jwk": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, jwkContextKey, nil) @@ -611,7 +627,7 @@ func TestHandler_GetChallenge(t *testing.T) { }, "fail/validate-challenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -639,8 +655,8 @@ func TestHandler_GetChallenge(t *testing.T) { return acme.NewErrorISE("force") }, }, - vco: &acme.ValidateChallengeOptions{ - HTTPGet: func(string) (*http.Response, error) { + vc: &mockClient{ + get: func(string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -651,14 +667,13 @@ func TestHandler_GetChallenge(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := _jwk.Public() ctx = context.WithValue(ctx, jwkContextKey, &_pub) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ @@ -690,8 +705,8 @@ func TestHandler_GetChallenge(t *testing.T) { URL: u, Error: acme.NewError(acme.ErrorConnectionType, "force"), }, - vco: &acme.ValidateChallengeOptions{ - HTTPGet: func(string) (*http.Response, error) { + vc: &mockClient{ + get: func(string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -703,11 +718,11 @@ func TestHandler_GetChallenge(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} + ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetChallenge(w, req) + GetChallenge(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 10f7841f..a254a83b 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -9,7 +9,6 @@ import ( "net/url" "strings" - "github.com/go-chi/chi" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" @@ -31,39 +30,11 @@ func logNonce(w http.ResponseWriter, nonce string) { } } -// baseURLFromRequest determines the base URL which should be used for -// constructing link URLs in e.g. the ACME directory result by taking the -// request Host into consideration. -// -// If the Request.Host is an empty string, we return an empty string, to -// indicate that the configured URL values should be used instead. If this -// function returns a non-empty result, then this should be used in -// constructing ACME link URLs. -func baseURLFromRequest(r *http.Request) *url.URL { - // NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go - // for an implementation that allows HTTP requests using the x-forwarded-proto - // header. - - if r.Host == "" { - return nil - } - return &url.URL{Scheme: "https", Host: r.Host} -} - -// baseURLFromRequest is a middleware that extracts and caches the baseURL -// from the request. -// E.g. https://ca.smallstep.com/ -func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP { - return func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r)) - next(w, r.WithContext(ctx)) - } -} - // addNonce is a middleware that adds a nonce to the response header. -func (h *Handler) addNonce(next nextHTTP) nextHTTP { +func addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - nonce, err := h.db.CreateNonce(r.Context()) + db := acme.MustDatabaseFromContext(r.Context()) + nonce, err := db.CreateNonce(r.Context()) if err != nil { render.Error(w, err) return @@ -77,25 +48,31 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP { // addDirLink is a middleware that adds a 'Link' response reader with the // directory index url. -func (h *Handler) addDirLink(next nextHTTP) nextHTTP { +func addDirLink(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index")) + ctx := r.Context() + linker := acme.MustLinkerFromContext(ctx) + + w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index")) next(w, r) } } // verifyContentType is a middleware that verifies that content type is // application/jose+json. -func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { +func verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - var expected []string p, err := provisionerFromContext(r.Context()) if err != nil { render.Error(w, err) return } - u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")} + u := &url.URL{ + Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""), + } + + var expected []string if strings.Contains(r.URL.String(), u.EscapedPath()) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} @@ -117,7 +94,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { } // parseJWS is a middleware that parses a request body into a JSONWebSignature struct. -func (h *Handler) parseJWS(next nextHTTP) nextHTTP { +func parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { @@ -149,10 +126,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { // * “nonce” (defined in Section 6.5) // * “url” (defined in Section 6.4) // * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below -func (h *Handler) validateJWS(next nextHTTP) nextHTTP { +func validateJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := jwsFromContext(r.Context()) + db := acme.MustDatabaseFromContext(ctx) + + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return @@ -202,7 +181,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { } // Check the validity/freshness of the Nonce. - if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { + if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { render.Error(w, err) return } @@ -235,10 +214,12 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { // extractJWK is a middleware that extracts the JWK from the JWS and saves it // in the context. Make sure to parse and validate the JWS before running this // middleware. -func (h *Handler) extractJWK(next nextHTTP) nextHTTP { +func extractJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := jwsFromContext(r.Context()) + db := acme.MustDatabaseFromContext(ctx) + + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return @@ -264,7 +245,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { ctx = context.WithValue(ctx, jwkContextKey, jwk) // Get Account OR continue to generate a new one OR continue Revoke with certificate private key - acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID) + acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID) switch { case errors.Is(err, acme.ErrNotFound): // For NewAccount and Revoke requests ... @@ -283,63 +264,44 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { } } -// lookupProvisioner loads the provisioner associated with the request. -// Responds 404 if the provisioner does not exist. -func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - nameEscaped := chi.URLParam(r, "provisionerID") - name, err := url.PathUnescape(nameEscaped) - if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) - return - } - p, err := h.ca.LoadProvisionerByName(name) - if err != nil { - render.Error(w, err) - return - } - acmeProv, ok := p.(*provisioner.ACME) - if !ok { - render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) - return - } - ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) - next(w, r.WithContext(ctx)) - } -} - // checkPrerequisites checks if all prerequisites for serving ACME // are met by the CA configuration. -func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { +func checkPrerequisites(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ok, err := h.prerequisitesChecker(ctx) - if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) - return - } - if !ok { - render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) - return + // If the function is not set assume that all prerequisites are met. + checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx) + if ok { + ok, err := checkFunc(ctx) + if err != nil { + render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) + return + } + if !ok { + render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) + return + } } - next(w, r.WithContext(ctx)) + next(w, r) } } // lookupJWK loads the JWK associated with the acme account referenced by the // kid parameter of the signed payload. // Make sure to parse and validate the JWS before running this middleware. -func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { +func lookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return } - kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") + kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { render.Error(w, acme.NewError(acme.ErrorMalformedType, @@ -349,7 +311,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { } accID := strings.TrimPrefix(kid, kidPrefix) - acc, err := h.db.GetAccount(ctx, accID) + acc, err := db.GetAccount(ctx, accID) switch { case nosql.IsErrNotFound(err): render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) @@ -372,7 +334,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { // extractOrLookupJWK forwards handling to either extractJWK or // lookupJWK based on the presence of a JWK or a KID, respectively. -func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { +func extractOrLookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) @@ -385,13 +347,13 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { // and it can be used to check if a JWK exists. This flow is used when the ACME client // signed the payload with a certificate private key. if canExtractJWKFrom(jws) { - h.extractJWK(next)(w, r) + extractJWK(next)(w, r) return } // default to looking up the JWK based on KeyID. This flow is used when the ACME client // signed the payload with an account private key. - h.lookupJWK(next)(w, r) + lookupJWK(next)(w, r) } } @@ -408,7 +370,7 @@ func canExtractJWKFrom(jws *jose.JSONWebSignature) bool { // verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context. // Make sure to parse and validate the JWS before running this middleware. -func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { +func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) @@ -440,7 +402,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { } // isPostAsGet asserts that the request is a PostAsGet (empty JWS payload). -func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { +func isPostAsGet(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { payload, err := payloadFromContext(r.Context()) if err != nil { @@ -462,16 +424,12 @@ type ContextKey string const ( // accContextKey account key accContextKey = ContextKey("acc") - // baseURLContextKey baseURL key - baseURLContextKey = ContextKey("baseURL") // jwsContextKey jws key jwsContextKey = ContextKey("jws") // jwkContextKey jwk key jwkContextKey = ContextKey("jwk") // payloadContextKey payload key payloadContextKey = ContextKey("payload") - // provisionerContextKey provisioner key - provisionerContextKey = ContextKey("provisioner") ) // accountFromContext searches the context for an ACME account. Returns the @@ -484,15 +442,6 @@ func accountFromContext(ctx context.Context) (*acme.Account, error) { return val, nil } -// baseURLFromContext returns the baseURL if one is stored in the context. -func baseURLFromContext(ctx context.Context) *url.URL { - val, ok := ctx.Value(baseURLContextKey).(*url.URL) - if !ok || val == nil { - return nil - } - return val -} - // jwkFromContext searches the context for a JWK. Returns the JWK or an error. func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey) @@ -514,29 +463,26 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { // provisionerFromContext searches the context for a provisioner. Returns the // provisioner or an error. func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { - val := ctx.Value(provisionerContextKey) - if val == nil { + p, ok := acme.ProvisionerFromContext(ctx) + if !ok || p == nil { return nil, acme.NewErrorISE("provisioner expected in request context") } - pval, ok := val.(acme.Provisioner) - if !ok || pval == nil { - return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") - } - return pval, nil + return p, nil } // acmeProvisionerFromContext searches the context for an ACME provisioner. Returns // pointer to an ACME provisioner or an error. func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) { - prov, err := provisionerFromContext(ctx) + p, err := provisionerFromContext(ctx) if err != nil { return nil, err } - acmeProv, ok := prov.(*provisioner.ACME) - if !ok || acmeProv == nil { + ap, ok := p.(*provisioner.ACME) + if !ok { return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") } - return acmeProv, nil + + return ap, nil } // payloadFromContext searches the context for a payload. Returns the payload diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 8003fa16..193f5347 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -27,83 +27,18 @@ func testNext(w http.ResponseWriter, r *http.Request) { w.Write(testBody) } -func Test_baseURLFromRequest(t *testing.T) { - tests := []struct { - name string - targetURL string - expectedResult *url.URL - requestPreparer func(*http.Request) - }{ - { - "HTTPS host pass-through failed.", - "https://my.dummy.host", - &url.URL{Scheme: "https", Host: "my.dummy.host"}, - nil, - }, - { - "Port pass-through failed", - "https://host.with.port:8080", - &url.URL{Scheme: "https", Host: "host.with.port:8080"}, - nil, - }, - { - "Explicit host from Request.Host was not used.", - "https://some.target.host:8080", - &url.URL{Scheme: "https", Host: "proxied.host"}, - func(r *http.Request) { - r.Host = "proxied.host" - }, - }, - { - "Missing Request.Host value did not result in empty string result.", - "https://some.host", - nil, - func(r *http.Request) { - r.Host = "" - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - request := httptest.NewRequest("GET", tc.targetURL, nil) - if tc.requestPreparer != nil { - tc.requestPreparer(request) - } - result := baseURLFromRequest(request) - if result == nil || tc.expectedResult == nil { - assert.Equals(t, result, tc.expectedResult) - } else if result.String() != tc.expectedResult.String() { - t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String()) - } - }) - } -} - -func TestHandler_baseURLFromRequest(t *testing.T) { - h := &Handler{} - req := httptest.NewRequest("GET", "/foo", nil) - req.Host = "test.ca.smallstep.com:8080" - w := httptest.NewRecorder() - - next := func(w http.ResponseWriter, r *http.Request) { - bu := baseURLFromContext(r.Context()) - if assert.NotNil(t, bu) { - assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080") - assert.Equals(t, bu.Scheme, "https") +func newBaseContext(ctx context.Context, args ...interface{}) context.Context { + for _, a := range args { + switch v := a.(type) { + case acme.DB: + ctx = acme.NewDatabaseContext(ctx, v) + case acme.Linker: + ctx = acme.NewLinkerContext(ctx, v) + case acme.PrerequisitesChecker: + ctx = acme.NewPrerequisitesCheckerContext(ctx, v) } } - - h.baseURLFromRequest(next)(w, req) - - req = httptest.NewRequest("GET", "/foo", nil) - req.Host = "" - - next = func(w http.ResponseWriter, r *http.Request) { - assert.Equals(t, baseURLFromContext(r.Context()), nil) - } - - h.baseURLFromRequest(next)(w, req) + return ctx } func TestHandler_addNonce(t *testing.T) { @@ -139,10 +74,10 @@ func TestHandler_addNonce(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} - req := httptest.NewRequest("GET", u, nil) + ctx := newBaseContext(context.Background(), tc.db) + req := httptest.NewRequest("GET", u, nil).WithContext(ctx) w := httptest.NewRecorder() - h.addNonce(testNext)(w, req) + addNonce(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -175,17 +110,15 @@ func TestHandler_addDirLink(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { link string - linker Linker statusCode int ctx context.Context err *acme.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) + ctx = acme.NewLinkerContext(ctx, acme.NewLinker("test.ca.smallstep.com", "acme")) return test{ - linker: NewLinker("dns", "acme"), ctx: ctx, link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName), statusCode: 200, @@ -195,11 +128,10 @@ func TestHandler_addDirLink(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: tc.linker} req := httptest.NewRequest("GET", "/foo", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.addDirLink(testNext)(w, req) + addDirLink(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -231,7 +163,6 @@ func TestHandler_verifyContentType(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) type test struct { - h Handler ctx context.Context contentType string err *acme.Error @@ -241,9 +172,6 @@ func TestHandler_verifyContentType(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/provisioner-not-set": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, url: u, ctx: context.Background(), contentType: "foo", @@ -253,11 +181,8 @@ func TestHandler_verifyContentType(t *testing.T) { }, "fail/general-bad-content-type": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, url: u, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "foo", statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"), @@ -265,10 +190,7 @@ func TestHandler_verifyContentType(t *testing.T) { }, "fail/certificate-bad-content-type": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "foo", statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"), @@ -276,40 +198,28 @@ func TestHandler_verifyContentType(t *testing.T) { }, "ok": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/jose+json", statusCode: 200, } }, "ok/certificate/pkix-cert": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/pkix-cert", statusCode: 200, } }, "ok/certificate/jose+json": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/jose+json", statusCode: 200, } }, "ok/certificate/pkcs7-mime": func(t *testing.T) test { return test{ - h: Handler{ - linker: NewLinker("dns", "acme"), - }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: acme.NewProvisionerContext(context.Background(), prov), contentType: "application/pkcs7-mime", statusCode: 200, } @@ -326,7 +236,7 @@ func TestHandler_verifyContentType(t *testing.T) { req = req.WithContext(tc.ctx) req.Header.Add("Content-Type", tc.contentType) w := httptest.NewRecorder() - tc.h.verifyContentType(testNext)(w, req) + verifyContentType(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -390,11 +300,11 @@ func TestHandler_isPostAsGet(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{} + // h := &Handler{} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.isPostAsGet(testNext)(w, req) + isPostAsGet(testNext)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -481,10 +391,10 @@ func TestHandler_parseJWS(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{} + // h := &Handler{} req := httptest.NewRequest("GET", u, tc.body) w := httptest.NewRecorder() - h.parseJWS(tc.next)(w, req) + parseJWS(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -679,11 +589,11 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{} + // h := &Handler{} req := httptest.NewRequest("GET", u, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.verifyAndExtractJWSPayload(tc.next)(w, req) + verifyAndExtractJWSPayload(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -733,7 +643,7 @@ func TestHandler_lookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) type test struct { - linker Linker + linker acme.Linker db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) @@ -743,15 +653,19 @@ func TestHandler_lookupJWK(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -765,11 +679,11 @@ func TestHandler_lookupJWK(t *testing.T) { assert.FatalError(t, err) _jws, err := _signer.Sign([]byte("baz")) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix), @@ -789,22 +703,21 @@ func TestHandler_lookupJWK(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _parsed) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + db: &acme.MockDB{}, + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix), } }, "fail/account-not-found": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { assert.Equals(t, accID, accID) @@ -817,11 +730,10 @@ func TestHandler_lookupJWK(t *testing.T) { } }, "fail/GetAccount-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { assert.Equals(t, id, accID) @@ -835,11 +747,10 @@ func TestHandler_lookupJWK(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { assert.Equals(t, id, accID) @@ -853,11 +764,10 @@ func TestHandler_lookupJWK(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid", Key: jwk} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { assert.Equals(t, id, accID) @@ -881,11 +791,11 @@ func TestHandler_lookupJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: tc.linker} + ctx := newBaseContext(tc.ctx, tc.db, tc.linker) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.lookupJWK(tc.next)(w, req) + lookupJWK(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -945,15 +855,17 @@ func TestHandler_extractJWK(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -969,9 +881,10 @@ func TestHandler_extractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"), @@ -987,16 +900,17 @@ func TestHandler_extractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"), } }, "fail/GetAccountByKey-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1012,7 +926,7 @@ func TestHandler_extractJWK(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1028,7 +942,7 @@ func TestHandler_extractJWK(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1051,7 +965,7 @@ func TestHandler_extractJWK(t *testing.T) { } }, "ok/no-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, @@ -1077,11 +991,11 @@ func TestHandler_extractJWK(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.extractJWK(tc.next)(w, req) + extractJWK(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1118,6 +1032,7 @@ func TestHandler_validateJWS(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.Background(), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -1125,6 +1040,7 @@ func TestHandler_validateJWS(t *testing.T) { }, "fail/nil-jws": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, nil), statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -1132,6 +1048,7 @@ func TestHandler_validateJWS(t *testing.T) { }, "fail/no-signature": func(t *testing.T) test { return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"), @@ -1145,6 +1062,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"), @@ -1157,6 +1075,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"), @@ -1169,6 +1088,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"), @@ -1181,6 +1101,7 @@ func TestHandler_validateJWS(t *testing.T) { }, } return test{ + db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256), @@ -1444,11 +1365,11 @@ func TestHandler_validateJWS(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.validateJWS(tc.next)(w, req) + validateJWS(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1542,7 +1463,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { u := "https://ca.smallstep.com/acme/account" type test struct { db acme.DB - linker Linker + linker acme.Linker statusCode int ctx context.Context err *acme.Error @@ -1570,7 +1491,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), db: &acme.MockDB{ MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { assert.Equals(t, kid, pub.KeyID) @@ -1606,11 +1527,10 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ - linker: NewLinker("test.ca.smallstep.com", "acme"), + linker: acme.NewLinker("test.ca.smallstep.com", "acme"), db: &acme.MockDB{ MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { assert.Equals(t, accID, acc.ID) @@ -1628,11 +1548,11 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db, linker: tc.linker} + ctx := newBaseContext(tc.ctx, tc.db, tc.linker) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.extractOrLookupJWK(tc.next)(w, req) + extractOrLookupJWK(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1664,7 +1584,7 @@ func TestHandler_checkPrerequisites(t *testing.T) { u := fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provName) type test struct { - linker Linker + linker acme.Linker ctx context.Context prerequisitesChecker func(context.Context) (bool, error) next func(http.ResponseWriter, *http.Request) @@ -1673,10 +1593,9 @@ func TestHandler_checkPrerequisites(t *testing.T) { } var tests = map[string]func(t *testing.T) test{ "fail/error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), ctx: ctx, prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") }, next: func(w http.ResponseWriter, r *http.Request) { @@ -1687,10 +1606,9 @@ func TestHandler_checkPrerequisites(t *testing.T) { } }, "fail/prerequisites-nok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), ctx: ctx, prerequisitesChecker: func(context.Context) (bool, error) { return false, nil }, next: func(w http.ResponseWriter, r *http.Request) { @@ -1701,10 +1619,9 @@ func TestHandler_checkPrerequisites(t *testing.T) { } }, "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := acme.NewProvisionerContext(context.Background(), prov) return test{ - linker: NewLinker("dns", "acme"), + linker: acme.NewLinker("dns", "acme"), ctx: ctx, prerequisitesChecker: func(context.Context) (bool, error) { return true, nil }, next: func(w http.ResponseWriter, r *http.Request) { @@ -1717,11 +1634,11 @@ func TestHandler_checkPrerequisites(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker} + ctx := acme.NewPrerequisitesCheckerContext(tc.ctx, tc.prerequisitesChecker) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.checkPrerequisites(tc.next)(w, req) + checkPrerequisites(tc.next)(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) diff --git a/acme/api/order.go b/acme/api/order.go index c37285d2..4e829b42 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -72,8 +72,12 @@ var defaultOrderExpiry = time.Hour * 24 var defaultOrderBackdate = time.Minute // NewOrder ACME api for creating a new order. -func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { +func NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + ca := mustAuthority(ctx) + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -113,7 +117,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { var eak *acme.ExternalAccountKey if acmeProv.RequireEAB { - if eak, err = h.db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil { + if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key")) return } @@ -138,7 +142,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { return } // evaluate the authority level policy - if err = h.ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil { + if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil { render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) return } @@ -164,7 +168,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { ExpiresAt: o.ExpiresAt, Status: acme.StatusPending, } - if err := h.newAuthorization(ctx, az); err != nil { + if err := newAuthorization(ctx, az); err != nil { render.Error(w, err) return } @@ -183,14 +187,14 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate) } - if err := h.db.CreateOrder(ctx, o); err != nil { + if err := db.CreateOrder(ctx, o); err != nil { render.Error(w, acme.WrapErrorISE(err, "error creating order")) return } - h.linker.LinkOrder(ctx, o) + linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSONStatus(w, o, http.StatusCreated) } @@ -208,7 +212,7 @@ func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error return policy.NewX509PolicyEngine(eak.Policy) } -func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { +func newAuthorization(ctx context.Context, az *acme.Authorization) error { if strings.HasPrefix(az.Identifier.Value, "*.") { az.Wildcard = true az.Identifier = acme.Identifier{ @@ -224,6 +228,8 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) if err != nil { return acme.WrapErrorISE(err, "error generating random alphanumeric ID") } + + db := acme.MustDatabaseFromContext(ctx) az.Challenges = make([]*acme.Challenge, len(chTypes)) for i, typ := range chTypes { ch := &acme.Challenge{ @@ -233,20 +239,23 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) Token: az.Token, Status: acme.StatusPending, } - if err := h.db.CreateChallenge(ctx, ch); err != nil { + if err := db.CreateChallenge(ctx, ch); err != nil { return acme.WrapErrorISE(err, "error creating challenge") } az.Challenges[i] = ch } - if err = h.db.CreateAuthorization(ctx, az); err != nil { + if err = db.CreateAuthorization(ctx, az); err != nil { return acme.WrapErrorISE(err, "error creating authorization") } return nil } // GetOrder ACME api for retrieving an order. -func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { +func GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -257,7 +266,8 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) + + o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return @@ -272,20 +282,23 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } - if err = o.UpdateStatus(ctx, h.db); err != nil { + if err = o.UpdateStatus(ctx, db); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating order status")) return } - h.linker.LinkOrder(ctx, o) + linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSON(w, o) } -// FinalizeOrder attemptst to finalize an order and create a certificate. -func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { +// FinalizeOrder attempts to finalize an order and create a certificate. +func FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -312,7 +325,7 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { return } - o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) + o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return @@ -327,14 +340,16 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } - if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { + + ca := mustAuthority(ctx) + if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil { render.Error(w, acme.WrapErrorISE(err, "error finalizing order")) return } - h.linker.LinkOrder(ctx, o) + linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSON(w, o) } diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 35abab65..fd438461 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -280,15 +280,17 @@ func TestHandler_GetOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -298,6 +300,7 @@ func TestHandler_GetOrder(t *testing.T) { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -305,9 +308,10 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx := acme.NewProvisionerContext(context.Background(), nil) ctx = context.WithValue(ctx, accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -315,7 +319,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/db.GetOrder-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -329,7 +333,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -345,7 +349,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/provisioner-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -361,7 +365,7 @@ func TestHandler_GetOrder(t *testing.T) { }, "fail/order-update-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ @@ -385,10 +389,9 @@ func TestHandler_GetOrder(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { @@ -425,11 +428,11 @@ func TestHandler_GetOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.GetOrder(w, req) + GetOrder(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -640,8 +643,8 @@ func TestHandler_newAuthorization(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - h := &Handler{db: tc.db} - if err := h.newAuthorization(context.Background(), tc.az); err != nil { + ctx := newBaseContext(context.Background(), tc.db) + if err := newAuthorization(ctx, tc.az); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *acme.Error: @@ -682,15 +685,17 @@ func TestHandler_NewOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -700,6 +705,7 @@ func TestHandler_NewOrder(t *testing.T) { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -707,9 +713,10 @@ func TestHandler_NewOrder(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -718,8 +725,9 @@ func TestHandler_NewOrder(t *testing.T) { "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -727,21 +735,23 @@ func TestHandler_NewOrder(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, - err: acme.NewErrorISE("paylod does not exist"), + err: acme.NewErrorISE("payload does not exist"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"), @@ -752,10 +762,11 @@ func TestHandler_NewOrder(t *testing.T) { fr := &NewOrderRequest{} b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), @@ -770,7 +781,7 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, &acme.MockProvisioner{}) + ctx := acme.NewProvisionerContext(context.Background(), &acme.MockProvisioner{}) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ @@ -798,7 +809,7 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv) + ctx := acme.NewProvisionerContext(context.Background(), acmeProv) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ @@ -826,7 +837,7 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv) + ctx := acme.NewProvisionerContext(context.Background(), acmeProv) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ @@ -862,7 +873,7 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv) + ctx := acme.NewProvisionerContext(context.Background(), acmeProv) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ @@ -905,7 +916,7 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy) + ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ @@ -948,7 +959,7 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy) + ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ @@ -986,7 +997,7 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ @@ -1020,7 +1031,7 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) var ( @@ -1096,10 +1107,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3, ch4 **acme.Challenge az1ID, az2ID *string @@ -1217,10 +1227,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1315,10 +1324,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1412,10 +1420,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1510,10 +1517,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1611,10 +1617,9 @@ func TestHandler_NewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy) + ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) var ( ch1, ch2, ch3 **acme.Challenge az1ID *string @@ -1701,11 +1706,12 @@ func TestHandler_NewOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca} + mockMustAuthority(t, tc.ca) + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.NewOrder(w, req) + NewOrder(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1738,6 +1744,7 @@ func TestHandler_NewOrder(t *testing.T) { } func TestHandler_FinalizeOrder(t *testing.T) { + mockMustAuthority(t, &mockCA{}) prov := newProv() escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} @@ -1796,15 +1803,17 @@ func TestHandler_FinalizeOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: acme.NewProvisionerContext(context.Background(), prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -1814,6 +1823,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -1821,9 +1831,10 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -1832,8 +1843,9 @@ func TestHandler_FinalizeOrder(t *testing.T) { "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), accContextKey, acc) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -1841,21 +1853,23 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, - err: acme.NewErrorISE("paylod does not exist"), + err: acme.NewErrorISE("payload does not exist"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"), @@ -1866,10 +1880,11 @@ func TestHandler_FinalizeOrder(t *testing.T) { fr := &FinalizeRequest{} b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), @@ -1878,7 +1893,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { "fail/db.GetOrder-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1893,7 +1908,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1910,7 +1925,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/provisioner-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1927,7 +1942,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "fail/order-finalize-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -1952,10 +1967,9 @@ func TestHandler_FinalizeOrder(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accountID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ @@ -1991,11 +2005,11 @@ func TestHandler_FinalizeOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) req := httptest.NewRequest("GET", u, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.FinalizeOrder(w, req) + FinalizeOrder(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) diff --git a/acme/api/revoke.go b/acme/api/revoke.go index 4b71bc22..a8b98f3f 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -26,9 +26,11 @@ type revokePayload struct { } // RevokeCert attempts to revoke a certificate. -func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { - +func RevokeCert(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) @@ -69,7 +71,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { } serial := certToBeRevoked.SerialNumber.String() - dbCert, err := h.db.GetCertificateBySerial(ctx, serial) + dbCert, err := db.GetCertificateBySerial(ctx, serial) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) return @@ -87,7 +89,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) + acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) if acmeErr != nil { render.Error(w, acmeErr) return @@ -103,7 +105,8 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { } } - hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) + ca := mustAuthority(ctx) + hasBeenRevokedBefore, err := ca.IsRevoked(serial) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) return @@ -130,14 +133,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { } options := revokeOptions(serial, certToBeRevoked, reasonCode) - err = h.ca.Revoke(ctx, options) + err = ca.Revoke(ctx, options) if err != nil { render.Error(w, wrapRevokeErr(err)) return } logRevoke(w, options) - w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index")) + w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index")) w.Write(nil) } @@ -148,7 +151,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { // the identifiers in the certificate are extracted and compared against the (valid) Authorizations // that are stored for the ACME Account. If these sets match, the Account is considered authorized // to revoke the certificate. If this check fails, the client will receive an unauthorized error. -func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { +func isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { if !account.IsValid() { return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil) } diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index 9b1fd6d5..240ac748 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -521,6 +521,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/no-jws": func(t *testing.T) test { ctx := context.Background() return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -529,6 +530,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/nil-jws": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("jws expected in request context"), @@ -537,6 +539,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/no-provisioner": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -544,8 +547,9 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/nil-provisioner": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, nil) + ctx = acme.NewProvisionerContext(ctx, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("provisioner does not exist"), @@ -553,8 +557,9 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/no-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -562,9 +567,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("payload does not exist"), @@ -573,9 +579,10 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/unmarshal-payload": func(t *testing.T) test { malformedPayload := []byte(`{"payload":malformed?}`) ctx := context.WithValue(context.Background(), jwsContextKey, jws) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = acme.NewProvisionerContext(ctx, prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload}) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 500, err: acme.NewErrorISE("error unmarshaling payload"), @@ -587,10 +594,11 @@ func TestHandler_RevokeCert(t *testing.T) { } wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: &acme.Error{ @@ -606,10 +614,11 @@ func TestHandler_RevokeCert(t *testing.T) { } emptyPayloadBytes, err := json.Marshal(emptyPayload) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) return test{ + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: &acme.Error{ @@ -620,7 +629,7 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/db.GetCertificateBySerial": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ @@ -638,7 +647,7 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/different-certificate-contents": func(t *testing.T) test { aDifferentCert, _, err := generateCertKeyPair() assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ @@ -657,7 +666,7 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/no-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ @@ -676,7 +685,7 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, accContextKey, nil) @@ -697,11 +706,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -727,11 +735,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/account-not-authorized": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -781,10 +788,9 @@ func TestHandler_RevokeCert(t *testing.T) { assert.FatalError(t, err) unauthorizedPayloadBytes, err := json.Marshal(jwsPayload) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -808,11 +814,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/certificate-revoked-check-fails": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -842,7 +847,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/certificate-already-revoked": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -880,7 +885,7 @@ func TestHandler_RevokeCert(t *testing.T) { invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload) assert.FatalError(t, err) acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -918,7 +923,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, } acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, mockACMEProv) + ctx := acme.NewProvisionerContext(context.Background(), mockACMEProv) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -950,7 +955,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/ca.Revoke": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -982,7 +987,7 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/ca.Revoke-already-revoked": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -1013,11 +1018,10 @@ func TestHandler_RevokeCert(t *testing.T) { }, "ok/using-account-key": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -1041,10 +1045,9 @@ func TestHandler_RevokeCert(t *testing.T) { assert.FatalError(t, err) jws, err := jose.ParseJWS(string(jwsBytes)) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := acme.NewProvisionerContext(context.Background(), prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { @@ -1067,11 +1070,12 @@ func TestHandler_RevokeCert(t *testing.T) { for name, setup := range tests { tc := setup(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca} + ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme")) + mockMustAuthority(t, tc.ca) req := httptest.NewRequest("POST", revokeURL, nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.RevokeCert(w, req) + RevokeCert(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -1208,8 +1212,8 @@ func TestHandler_isAccountAuthorized(t *testing.T) { for name, setup := range tests { tc := setup(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} - acmeErr := h.isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account) + // h := &Handler{db: tc.db} + acmeErr := isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account) expectError := tc.err != nil gotError := acmeErr != nil diff --git a/acme/challenge.go b/acme/challenge.go index 9f08bae5..8d8466bd 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -14,7 +14,6 @@ import ( "fmt" "io" "net" - "net/http" "net/url" "reflect" "strings" @@ -61,27 +60,28 @@ func (ch *Challenge) ToLog() (interface{}, error) { // type using the DB interface. // satisfactorily validated, the 'status' and 'validated' attributes are // updated. -func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey) error { // If already valid or invalid then return without performing validation. if ch.Status != StatusPending { return nil } switch ch.Type { case HTTP01: - return http01Validate(ctx, ch, db, jwk, vo) + return http01Validate(ctx, ch, db, jwk) case DNS01: - return dns01Validate(ctx, ch, db, jwk, vo) + return dns01Validate(ctx, ch, db, jwk) case TLSALPN01: - return tlsalpn01Validate(ctx, ch, db, jwk, vo) + return tlsalpn01Validate(ctx, ch, db, jwk) default: return NewErrorISE("unexpected challenge type '%s'", ch.Type) } } -func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} - resp, err := vo.HTTPGet(u.String()) + vc := MustClientFromContext(ctx) + resp, err := vc.Get(u.String()) if err != nil { return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, "error doing http GET for url %s", u)) @@ -141,7 +141,7 @@ func tlsAlert(err error) uint8 { return 0 } -func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { config := &tls.Config{ NextProtos: []string{"acme-tls/1"}, // https://tools.ietf.org/html/rfc8737#section-4 @@ -154,7 +154,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON hostPort := net.JoinHostPort(ch.Value, "443") - conn, err := vo.TLSDial("tcp", hostPort, config) + vc := MustClientFromContext(ctx) + conn, err := vc.TLSDial("tcp", hostPort, config) if err != nil { // With Go 1.17+ tls.Dial fails if there's no overlap between configured // client and server protocols. When this happens the connection is @@ -253,14 +254,15 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) } -func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { // Normalize domain for wildcard DNS names // This is done to avoid making TXT lookups for domains like // _acme-challenge.*.example.com // Instead perform txt lookup for _acme-challenge.example.com domain := strings.TrimPrefix(ch.Value, "*.") - txtRecords, err := vo.LookupTxt("_acme-challenge." + domain) + vc := MustClientFromContext(ctx) + txtRecords, err := vc.LookupTxt("_acme-challenge." + domain) if err != nil { return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err, "error looking up TXT records for domain %s", domain)) @@ -376,14 +378,3 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err } return nil } - -type httpGetter func(string) (*http.Response, error) -type lookupTxt func(string) ([]string, error) -type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) - -// ValidateChallengeOptions are ACME challenge validator functions. -type ValidateChallengeOptions struct { - HTTPGet httpGetter - LookupTxt lookupTxt - TLSDial tlsDialer -} diff --git a/acme/challenge_test.go b/acme/challenge_test.go index c05b25e7..e1b6816a 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -29,6 +29,18 @@ import ( "github.com/smallstep/assert" ) +type mockClient struct { + get func(url string) (*http.Response, error) + lookupTxt func(name string) ([]string, error) + tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error) +} + +func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) } +func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) } +func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { + return m.tlsDial(network, addr, config) +} + func Test_storeError(t *testing.T) { type test struct { ch *Challenge @@ -229,7 +241,7 @@ func TestKeyAuthorization(t *testing.T) { func TestChallenge_Validate(t *testing.T) { type test struct { ch *Challenge - vo *ValidateChallengeOptions + vc Client jwk *jose.JSONWebKey db DB srv *httptest.Server @@ -273,8 +285,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -309,8 +321,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -344,8 +356,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -381,8 +393,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -416,8 +428,8 @@ func TestChallenge_Validate(t *testing.T) { } return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, @@ -466,8 +478,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -493,7 +505,8 @@ func TestChallenge_Validate(t *testing.T) { defer tc.srv.Close() } - if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := tc.ch.Validate(ctx, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -524,7 +537,7 @@ func (errReader) Close() error { func TestHTTP01Validate(t *testing.T) { type test struct { - vo *ValidateChallengeOptions + vc Client ch *Challenge jwk *jose.JSONWebKey db DB @@ -541,8 +554,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -575,8 +588,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -608,8 +621,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, Body: errReader(0), @@ -645,8 +658,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, Body: errReader(0), @@ -681,8 +694,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: errReader(0), }, nil @@ -704,8 +717,8 @@ func TestHTTP01Validate(t *testing.T) { jwk.Key = "foo" return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil @@ -730,8 +743,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil @@ -772,8 +785,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil @@ -815,8 +828,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), }, nil @@ -857,8 +870,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), }, nil @@ -887,7 +900,8 @@ func TestHTTP01Validate(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := http01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -911,7 +925,7 @@ func TestDNS01Validate(t *testing.T) { fulldomain := "*.zap.internal" domain := strings.TrimPrefix(fulldomain, "*.") type test struct { - vo *ValidateChallengeOptions + vc Client ch *Challenge jwk *jose.JSONWebKey db DB @@ -928,8 +942,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -963,8 +977,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -1001,8 +1015,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo"}, nil }, }, @@ -1026,8 +1040,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", "bar"}, nil }, }, @@ -1068,8 +1082,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", "bar"}, nil }, }, @@ -1111,8 +1125,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", expected}, nil }, }, @@ -1156,8 +1170,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", expected}, nil }, }, @@ -1186,7 +1200,8 @@ func TestDNS01Validate(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := dns01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -1206,6 +1221,8 @@ func TestDNS01Validate(t *testing.T) { } } +type tlsDialer func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) + func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) { srv := httptest.NewUnstartedServer(http.NewServeMux()) @@ -1309,7 +1326,7 @@ func TestTLSALPN01Validate(t *testing.T) { } } type test struct { - vo *ValidateChallengeOptions + vc Client ch *Challenge jwk *jose.JSONWebKey db DB @@ -1321,8 +1338,8 @@ func TestTLSALPN01Validate(t *testing.T) { ch := makeTLSCh() return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, @@ -1351,8 +1368,8 @@ func TestTLSALPN01Validate(t *testing.T) { ch := makeTLSCh() return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, @@ -1384,8 +1401,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1413,8 +1430,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.Client(&noopConn{}, config), nil }, }, @@ -1443,8 +1460,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.Client(&noopConn{}, config), nil }, }, @@ -1479,8 +1496,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) }, }, @@ -1516,8 +1533,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) }, }, @@ -1562,8 +1579,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1605,8 +1622,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1649,8 +1666,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1692,8 +1709,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1736,8 +1753,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, srv: srv, jwk: jwk, @@ -1758,8 +1775,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1797,8 +1814,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1841,8 +1858,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1884,8 +1901,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1924,8 +1941,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1963,8 +1980,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2008,8 +2025,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2054,8 +2071,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2100,8 +2117,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2144,8 +2161,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2189,8 +2206,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2226,8 +2243,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2253,7 +2270,8 @@ func TestTLSALPN01Validate(t *testing.T) { defer tc.srv.Close() } - if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := tlsalpn01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: diff --git a/acme/client.go b/acme/client.go new file mode 100644 index 00000000..31f4c975 --- /dev/null +++ b/acme/client.go @@ -0,0 +1,79 @@ +package acme + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "time" +) + +// Client is the interface used to verify ACME challenges. +type Client interface { + // Get issues an HTTP GET to the specified URL. + Get(url string) (*http.Response, error) + + // LookupTXT returns the DNS TXT records for the given domain name. + LookupTxt(name string) ([]string, error) + + // TLSDial connects to the given network address using net.Dialer and then + // initiates a TLS handshake, returning the resulting TLS connection. + TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) +} + +type clientKey struct{} + +// NewClientContext adds the given client to the context. +func NewClientContext(ctx context.Context, c Client) context.Context { + return context.WithValue(ctx, clientKey{}, c) +} + +// ClientFromContext returns the current client from the given context. +func ClientFromContext(ctx context.Context) (c Client, ok bool) { + c, ok = ctx.Value(clientKey{}).(Client) + return +} + +// MustClientFromContext returns the current client from the given context. It will +// return a new instance of the client if it does not exist. +func MustClientFromContext(ctx context.Context) Client { + c, ok := ClientFromContext(ctx) + if !ok { + return NewClient() + } + return c +} + +type client struct { + http *http.Client + dialer *net.Dialer +} + +// NewClient returns an implementation of Client for verifying ACME challenges. +func NewClient() Client { + return &client{ + http: &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + }, + dialer: &net.Dialer{ + Timeout: 30 * time.Second, + }, + } +} + +func (c *client) Get(url string) (*http.Response, error) { + return c.http.Get(url) +} + +func (c *client) LookupTxt(name string) ([]string, error) { + return net.LookupTXT(name) +} + +func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(c.dialer, network, addr, config) +} diff --git a/acme/common.go b/acme/common.go index e0d96deb..3054abe1 100644 --- a/acme/common.go +++ b/acme/common.go @@ -9,6 +9,16 @@ import ( "github.com/smallstep/certificates/authority/provisioner" ) +// Clock that returns time in UTC rounded to seconds. +type Clock struct{} + +// Now returns the UTC time rounded to seconds. +func (c *Clock) Now() time.Time { + return time.Now().UTC().Truncate(time.Second) +} + +var clock Clock + // CertificateAuthority is the interface implemented by a CA authority. type CertificateAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) @@ -18,15 +28,42 @@ type CertificateAuthority interface { LoadProvisionerByName(string) (provisioner.Interface, error) } -// Clock that returns time in UTC rounded to seconds. -type Clock struct{} +// NewContext adds the given acme components to the context. +func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context { + ctx = NewDatabaseContext(ctx, db) + ctx = NewClientContext(ctx, client) + ctx = NewLinkerContext(ctx, linker) + // Prerequisite checker is optional. + if fn != nil { + ctx = NewPrerequisitesCheckerContext(ctx, fn) + } + return ctx +} -// Now returns the UTC time rounded to seconds. -func (c *Clock) Now() time.Time { - return time.Now().UTC().Truncate(time.Second) +// PrerequisitesChecker is a function that checks if all prerequisites for +// serving ACME are met by the CA configuration. +type PrerequisitesChecker func(ctx context.Context) (bool, error) + +// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns +// always true. +func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) { + return true, nil } -var clock Clock +type prerequisitesKey struct{} + +// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the +// context. +func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context { + return context.WithValue(ctx, prerequisitesKey{}, fn) +} + +// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the +// context. +func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) { + fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker) + return fn, ok && fn != nil +} // Provisioner is an interface that implements a subset of the provisioner.Interface -- // only those methods required by the ACME api/authority. @@ -40,6 +77,29 @@ type Provisioner interface { GetOptions() *provisioner.Options } +type provisionerKey struct{} + +// NewProvisionerContext adds the given provisioner to the context. +func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context { + return context.WithValue(ctx, provisionerKey{}, v) +} + +// ProvisionerFromContext returns the current provisioner from the given context. +func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) { + v, ok = ctx.Value(provisionerKey{}).(Provisioner) + return +} + +// MustLinkerFromContext returns the current provisioner from the given context. +// It will panic if it's not in the context. +func MustProvisionerFromContext(ctx context.Context) Provisioner { + if v, ok := ProvisionerFromContext(ctx); !ok { + panic("acme provisioner is not the context") + } else { + return v + } +} + // MockProvisioner for testing type MockProvisioner struct { Mret1 interface{} diff --git a/acme/db.go b/acme/db.go index b53cb397..d7c9d5f4 100644 --- a/acme/db.go +++ b/acme/db.go @@ -49,6 +49,29 @@ type DB interface { UpdateOrder(ctx context.Context, o *Order) error } +type dbKey struct{} + +// NewDatabaseContext adds the given acme database to the context. +func NewDatabaseContext(ctx context.Context, db DB) context.Context { + return context.WithValue(ctx, dbKey{}, db) +} + +// DatabaseFromContext returns the current acme database from the given context. +func DatabaseFromContext(ctx context.Context) (db DB, ok bool) { + db, ok = ctx.Value(dbKey{}).(DB) + return +} + +// MustDatabaseFromContext returns the current database from the given context. +// It will panic if it's not in the context. +func MustDatabaseFromContext(ctx context.Context) DB { + if db, ok := DatabaseFromContext(ctx); !ok { + panic("acme database is not in the context") + } else { + return db + } +} + // MockDB is an implementation of the DB interface that should only be used as // a mock in tests. type MockDB struct { diff --git a/acme/api/linker.go b/acme/linker.go similarity index 59% rename from acme/api/linker.go rename to acme/linker.go index a605ffc3..bddc21f1 100644 --- a/acme/api/linker.go +++ b/acme/linker.go @@ -1,100 +1,19 @@ -package api +package acme import ( "context" "fmt" "net" + "net/http" "net/url" "strings" - "github.com/smallstep/certificates/acme" + "github.com/go-chi/chi" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/provisioner" ) -// NewLinker returns a new Directory type. -func NewLinker(dns, prefix string) Linker { - _, _, err := net.SplitHostPort(dns) - if err != nil && strings.Contains(err.Error(), "too many colons in address") { - // this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334 - // in case a port was appended to this wrong format, we try to extract the port, then check if it's - // still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of - // these cases, then the input dns is not changed. - lastIndex := strings.LastIndex(dns, ":") - hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:] - if ip := net.ParseIP(hostPart); ip != nil { - dns = "[" + hostPart + "]:" + portPart - } else if ip := net.ParseIP(dns); ip != nil { - dns = "[" + dns + "]" - } - } - return &linker{prefix: prefix, dns: dns} -} - -// Linker interface for generating links for ACME resources. -type Linker interface { - GetLink(ctx context.Context, typ LinkType, inputs ...string) string - GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string - - LinkOrder(ctx context.Context, o *acme.Order) - LinkAccount(ctx context.Context, o *acme.Account) - LinkChallenge(ctx context.Context, o *acme.Challenge, azID string) - LinkAuthorization(ctx context.Context, o *acme.Authorization) - LinkOrdersByAccountID(ctx context.Context, orders []string) -} - -// linker generates ACME links. -type linker struct { - prefix string - dns string -} - -func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string { - switch typ { - case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: - return fmt.Sprintf("/%s/%s", provisionerName, typ) - case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: - return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) - case ChallengeLinkType: - return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) - case OrdersByAccountLinkType: - return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) - case FinalizeLinkType: - return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) - default: - return "" - } -} - -// GetLink is a helper for GetLinkExplicit -func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string { - var ( - provName string - baseURL = baseURLFromContext(ctx) - u = url.URL{} - ) - if p, err := provisionerFromContext(ctx); err == nil && p != nil { - provName = p.GetName() - } - // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 - if baseURL != nil { - u = *baseURL - } - - u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...) - - // If no Scheme is set, then default to https. - if u.Scheme == "" { - u.Scheme = "https" - } - - // If no Host is set, then use the default (first DNS attr in the ca.json). - if u.Host == "" { - u.Host = l.dns - } - - u.Path = l.prefix + u.Path - return u.String() -} - // LinkType captures the link type. type LinkType int @@ -160,8 +79,155 @@ func (l LinkType) String() string { } } +func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string { + switch typ { + case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: + return fmt.Sprintf("/%s/%s", provisionerName, typ) + case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: + return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) + case ChallengeLinkType: + return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) + case OrdersByAccountLinkType: + return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) + case FinalizeLinkType: + return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) + default: + return "" + } +} + +// NewLinker returns a new Directory type. +func NewLinker(dns, prefix string) Linker { + _, _, err := net.SplitHostPort(dns) + if err != nil && strings.Contains(err.Error(), "too many colons in address") { + // this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334 + // in case a port was appended to this wrong format, we try to extract the port, then check if it's + // still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of + // these cases, then the input dns is not changed. + lastIndex := strings.LastIndex(dns, ":") + hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:] + if ip := net.ParseIP(hostPart); ip != nil { + dns = "[" + hostPart + "]:" + portPart + } else if ip := net.ParseIP(dns); ip != nil { + dns = "[" + dns + "]" + } + } + return &linker{prefix: prefix, dns: dns} +} + +// Linker interface for generating links for ACME resources. +type Linker interface { + GetLink(ctx context.Context, typ LinkType, inputs ...string) string + Middleware(http.Handler) http.Handler + LinkOrder(ctx context.Context, o *Order) + LinkAccount(ctx context.Context, o *Account) + LinkChallenge(ctx context.Context, o *Challenge, azID string) + LinkAuthorization(ctx context.Context, o *Authorization) + LinkOrdersByAccountID(ctx context.Context, orders []string) +} + +type linkerKey struct{} + +// NewLinkerContext adds the given linker to the context. +func NewLinkerContext(ctx context.Context, v Linker) context.Context { + return context.WithValue(ctx, linkerKey{}, v) +} + +// LinkerFromContext returns the current linker from the given context. +func LinkerFromContext(ctx context.Context) (v Linker, ok bool) { + v, ok = ctx.Value(linkerKey{}).(Linker) + return +} + +// MustLinkerFromContext returns the current linker from the given context. It +// will panic if it's not in the context. +func MustLinkerFromContext(ctx context.Context) Linker { + if v, ok := LinkerFromContext(ctx); !ok { + panic("acme linker is not the context") + } else { + return v + } +} + +type baseURLKey struct{} + +func newBaseURLContext(ctx context.Context, r *http.Request) context.Context { + var u *url.URL + if r.Host != "" { + u = &url.URL{Scheme: "https", Host: r.Host} + } + return context.WithValue(ctx, baseURLKey{}, u) +} + +func baseURLFromContext(ctx context.Context) *url.URL { + if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok { + return u + } + return nil +} + +// linker generates ACME links. +type linker struct { + prefix string + dns string +} + +// Middleware gets the provisioner and current url from the request and sets +// them in the context so we can use the linker to create ACME links. +func (l *linker) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Add base url to the context. + ctx := newBaseURLContext(r.Context(), r) + + // Add provisioner to the context. + nameEscaped := chi.URLParam(r, "provisionerID") + name, err := url.PathUnescape(nameEscaped) + if err != nil { + render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) + return + } + + p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name) + if err != nil { + render.Error(w, err) + return + } + + acmeProv, ok := p.(*provisioner.ACME) + if !ok { + render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) + return + } + + ctx = NewProvisionerContext(ctx, Provisioner(acmeProv)) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// GetLink is a helper for GetLinkExplicit. +func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string { + var name string + if p, ok := ProvisionerFromContext(ctx); ok { + name = p.GetName() + } + + var u url.URL + if baseURL := baseURLFromContext(ctx); baseURL != nil { + u = *baseURL + } + if u.Scheme == "" { + u.Scheme = "https" + } + if u.Host == "" { + u.Host = l.dns + } + + u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...) + return u.String() +} + // LinkOrder sets the ACME links required by an ACME order. -func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { +func (l *linker) LinkOrder(ctx context.Context, o *Order) { o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs)) for i, azID := range o.AuthorizationIDs { o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID) @@ -173,17 +239,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { } // LinkAccount sets the ACME links required by an ACME account. -func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { +func (l *linker) LinkAccount(ctx context.Context, acc *Account) { acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID) } // LinkChallenge sets the ACME links required by an ACME challenge. -func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) { +func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) { ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID) } // LinkAuthorization sets the ACME links required by an ACME authorization. -func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { +func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) { for _, ch := range az.Challenges { l.LinkChallenge(ctx, ch, az.ID) } diff --git a/acme/api/linker_test.go b/acme/linker_test.go similarity index 82% rename from acme/api/linker_test.go rename to acme/linker_test.go index 74c2c8b0..b85d1a53 100644 --- a/acme/api/linker_test.go +++ b/acme/linker_test.go @@ -1,21 +1,38 @@ -package api +package acme import ( "context" "fmt" "net/url" "testing" + "time" "github.com/smallstep/assert" - "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/provisioner" ) -func TestLinker_GetUnescapedPathSuffix(t *testing.T) { - dns := "ca.smallstep.com" - prefix := "acme" - linker := NewLinker(dns, prefix) +func mockProvisioner(t *testing.T) Provisioner { + t.Helper() + var defaultDisableRenewal = false + + // Initialize provisioners + p := &provisioner.ACME{ + Type: "ACME", + Name: "test@acme-provisioner.com", + } + if err := p.Init(provisioner.Config{Claims: provisioner.Claims{ + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DisableRenewal: &defaultDisableRenewal, + }}); err != nil { + fmt.Printf("%v", err) + } + return p +} - getPath := linker.GetUnescapedPathSuffix +func TestGetUnescapedPathSuffix(t *testing.T) { + getPath := GetUnescapedPathSuffix assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce") assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory") @@ -32,9 +49,9 @@ func TestLinker_GetUnescapedPathSuffix(t *testing.T) { } func TestLinker_DNS(t *testing.T) { - prov := newProv() + prov := mockProvisioner(t) escProvName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) type test struct { name string dns string @@ -117,19 +134,19 @@ func TestLinker_GetLink(t *testing.T) { linker := NewLinker(dns, prefix) id := "1234" - prov := newProv() + prov := mockProvisioner(t) escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) // No provisioner and no BaseURL from request assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", "")) // Provisioner: yes, BaseURL: no - assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) + assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) // Provisioner: no, BaseURL: yes - assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", "")) + assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", "")) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) @@ -163,37 +180,37 @@ func TestLinker_GetLink(t *testing.T) { func TestLinker_LinkOrder(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) oid := "orderID" certID := "certID" linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - o *acme.Order - validate func(o *acme.Order) + o *Order + validate func(o *Order) } var tests = map[string]test{ "no-authz-and-no-cert": { - o: &acme.Order{ + o: &Order{ ID: oid, }, - validate: func(o *acme.Order) { + validate: func(o *Order) { assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.AuthorizationURLs, []string{}) assert.Equals(t, o.CertificateURL, "") }, }, "one-authz-and-cert": { - o: &acme.Order{ + o: &Order{ ID: oid, CertificateID: certID, AuthorizationIDs: []string{"foo"}, }, - validate: func(o *acme.Order) { + validate: func(o *Order) { assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.AuthorizationURLs, []string{ fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), @@ -202,12 +219,12 @@ func TestLinker_LinkOrder(t *testing.T) { }, }, "many-authz": { - o: &acme.Order{ + o: &Order{ ID: oid, CertificateID: certID, AuthorizationIDs: []string{"foo", "bar", "zap"}, }, - validate: func(o *acme.Order) { + validate: func(o *Order) { assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.AuthorizationURLs, []string{ fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), @@ -228,24 +245,24 @@ func TestLinker_LinkOrder(t *testing.T) { func TestLinker_LinkAccount(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) accID := "accountID" linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - a *acme.Account - validate func(o *acme.Account) + a *Account + validate func(o *Account) } var tests = map[string]test{ "ok": { - a: &acme.Account{ + a: &Account{ ID: accID, }, - validate: func(a *acme.Account) { + validate: func(a *Account) { assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID)) }, }, @@ -260,25 +277,25 @@ func TestLinker_LinkAccount(t *testing.T) { func TestLinker_LinkChallenge(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) chID := "chID" azID := "azID" linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - ch *acme.Challenge - validate func(o *acme.Challenge) + ch *Challenge + validate func(o *Challenge) } var tests = map[string]test{ "ok": { - ch: &acme.Challenge{ + ch: &Challenge{ ID: chID, }, - validate: func(ch *acme.Challenge) { + validate: func(ch *Challenge) { assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID)) }, }, @@ -293,10 +310,10 @@ func TestLinker_LinkChallenge(t *testing.T) { func TestLinker_LinkAuthorization(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) chID0 := "chID-0" chID1 := "chID-1" @@ -305,20 +322,20 @@ func TestLinker_LinkAuthorization(t *testing.T) { linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - az *acme.Authorization - validate func(o *acme.Authorization) + az *Authorization + validate func(o *Authorization) } var tests = map[string]test{ "ok": { - az: &acme.Authorization{ + az: &Authorization{ ID: azID, - Challenges: []*acme.Challenge{ + Challenges: []*Challenge{ {ID: chID0}, {ID: chID1}, {ID: chID2}, }, }, - validate: func(az *acme.Authorization) { + validate: func(az *Authorization) { assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0)) assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1)) assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2)) @@ -335,10 +352,10 @@ func TestLinker_LinkAuthorization(t *testing.T) { func TestLinker_LinkOrdersByAccountID(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) diff --git a/api/api.go b/api/api.go index da6309fd..75d26237 100644 --- a/api/api.go +++ b/api/api.go @@ -35,7 +35,6 @@ type Authority interface { SSHAuthority // context specifies the Authorize[Sign|Revoke|etc.] method. Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) - AuthorizeSign(ott string) ([]provisioner.SignOption, error) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) @@ -52,6 +51,11 @@ type Authority interface { Version() authority.Version } +// mustAuthority will be replaced on unit tests. +var mustAuthority = func(ctx context.Context) Authority { + return authority.MustFromContext(ctx) +} + // TimeDuration is an alias of provisioner.TimeDuration type TimeDuration = provisioner.TimeDuration @@ -243,48 +247,53 @@ type caHandler struct { Authority Authority } -// New creates a new RouterHandler with the CA endpoints. -func New(auth Authority) RouterHandler { - return &caHandler{ - Authority: auth, - } +// Route configures the http request router. +func (h *caHandler) Route(r Router) { + Route(r) } -func (h *caHandler) Route(r Router) { - r.MethodFunc("GET", "/version", h.Version) - r.MethodFunc("GET", "/health", h.Health) - r.MethodFunc("GET", "/root/{sha}", h.Root) - r.MethodFunc("POST", "/sign", h.Sign) - r.MethodFunc("POST", "/renew", h.Renew) - r.MethodFunc("POST", "/rekey", h.Rekey) - r.MethodFunc("POST", "/revoke", h.Revoke) - r.MethodFunc("GET", "/provisioners", h.Provisioners) - r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) - r.MethodFunc("GET", "/roots", h.Roots) - r.MethodFunc("GET", "/roots.pem", h.RootsPEM) - r.MethodFunc("GET", "/federation", h.Federation) +// New creates a new RouterHandler with the CA endpoints. +// +// Deprecated: Use api.Route(r Router) +func New(auth Authority) RouterHandler { + return &caHandler{} +} + +func Route(r Router) { + r.MethodFunc("GET", "/version", Version) + r.MethodFunc("GET", "/health", Health) + r.MethodFunc("GET", "/root/{sha}", Root) + r.MethodFunc("POST", "/sign", Sign) + r.MethodFunc("POST", "/renew", Renew) + r.MethodFunc("POST", "/rekey", Rekey) + r.MethodFunc("POST", "/revoke", Revoke) + r.MethodFunc("GET", "/provisioners", Provisioners) + r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey) + r.MethodFunc("GET", "/roots", Roots) + r.MethodFunc("GET", "/roots.pem", RootsPEM) + r.MethodFunc("GET", "/federation", Federation) // SSH CA - r.MethodFunc("POST", "/ssh/sign", h.SSHSign) - r.MethodFunc("POST", "/ssh/renew", h.SSHRenew) - r.MethodFunc("POST", "/ssh/revoke", h.SSHRevoke) - r.MethodFunc("POST", "/ssh/rekey", h.SSHRekey) - r.MethodFunc("GET", "/ssh/roots", h.SSHRoots) - r.MethodFunc("GET", "/ssh/federation", h.SSHFederation) - r.MethodFunc("POST", "/ssh/config", h.SSHConfig) - r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig) - r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost) - r.MethodFunc("GET", "/ssh/hosts", h.SSHGetHosts) - r.MethodFunc("POST", "/ssh/bastion", h.SSHBastion) + r.MethodFunc("POST", "/ssh/sign", SSHSign) + r.MethodFunc("POST", "/ssh/renew", SSHRenew) + r.MethodFunc("POST", "/ssh/revoke", SSHRevoke) + r.MethodFunc("POST", "/ssh/rekey", SSHRekey) + r.MethodFunc("GET", "/ssh/roots", SSHRoots) + r.MethodFunc("GET", "/ssh/federation", SSHFederation) + r.MethodFunc("POST", "/ssh/config", SSHConfig) + r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig) + r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost) + r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts) + r.MethodFunc("POST", "/ssh/bastion", SSHBastion) // For compatibility with old code: - r.MethodFunc("POST", "/re-sign", h.Renew) - r.MethodFunc("POST", "/sign-ssh", h.SSHSign) - r.MethodFunc("GET", "/ssh/get-hosts", h.SSHGetHosts) + r.MethodFunc("POST", "/re-sign", Renew) + r.MethodFunc("POST", "/sign-ssh", SSHSign) + r.MethodFunc("GET", "/ssh/get-hosts", SSHGetHosts) } // Version is an HTTP handler that returns the version of the server. -func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { - v := h.Authority.Version() +func Version(w http.ResponseWriter, r *http.Request) { + v := mustAuthority(r.Context()).Version() render.JSON(w, VersionResponse{ Version: v.Version, RequireClientAuthentication: v.RequireClientAuthentication, @@ -292,17 +301,17 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { } // Health is an HTTP handler that returns the status of the server. -func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) { +func Health(w http.ResponseWriter, r *http.Request) { render.JSON(w, HealthResponse{Status: "ok"}) } // Root is an HTTP handler that using the SHA256 from the URL, returns the root // certificate for the given SHA256. -func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { +func Root(w http.ResponseWriter, r *http.Request) { sha := chi.URLParam(r, "sha") sum := strings.ToLower(strings.ReplaceAll(sha, "-", "")) // Load root certificate with the - cert, err := h.Authority.Root(sum) + cert, err := mustAuthority(r.Context()).Root(sum) if err != nil { render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) return @@ -320,18 +329,19 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { } // Provisioners returns the list of provisioners configured in the authority. -func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { +func Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := ParseCursor(r) if err != nil { render.Error(w, err) return } - p, next, err := h.Authority.GetProvisioners(cursor, limit) + p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) if err != nil { render.Error(w, errs.InternalServerErr(err)) return } + render.JSON(w, &ProvisionersResponse{ Provisioners: p, NextCursor: next, @@ -339,19 +349,20 @@ func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { } // ProvisionerKey returns the encrypted key of a provisioner by it's key id. -func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { +func ProvisionerKey(w http.ResponseWriter, r *http.Request) { kid := chi.URLParam(r, "kid") - key, err := h.Authority.GetEncryptedKey(kid) + key, err := mustAuthority(r.Context()).GetEncryptedKey(kid) if err != nil { render.Error(w, errs.NotFoundErr(err)) return } + render.JSON(w, &ProvisionerKeyResponse{key}) } // Roots returns all the root certificates for the CA. -func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { - roots, err := h.Authority.GetRoots() +func Roots(w http.ResponseWriter, r *http.Request) { + roots, err := mustAuthority(r.Context()).GetRoots() if err != nil { render.Error(w, errs.ForbiddenErr(err, "error getting roots")) return @@ -368,8 +379,8 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { } // RootsPEM returns all the root certificates for the CA in PEM format. -func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { - roots, err := h.Authority.GetRoots() +func RootsPEM(w http.ResponseWriter, r *http.Request) { + roots, err := mustAuthority(r.Context()).GetRoots() if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -391,8 +402,8 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { } // Federation returns all the public certificates in the federation. -func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { - federated, err := h.Authority.GetFederation() +func Federation(w http.ResponseWriter, r *http.Request) { + federated, err := mustAuthority(r.Context()).GetFederation() if err != nil { render.Error(w, errs.ForbiddenErr(err, "error getting federated roots")) return diff --git a/api/api_test.go b/api/api_test.go index 39c77de7..1f27ab8c 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -171,10 +171,21 @@ func parseCertificateRequest(data string) *x509.CertificateRequest { return csr } +func mockMustAuthority(t *testing.T, a Authority) { + t.Helper() + fn := mustAuthority + t.Cleanup(func() { + mustAuthority = fn + }) + mustAuthority = func(ctx context.Context) Authority { + return a + } +} + type mockAuthority struct { ret1, ret2 interface{} err error - authorizeSign func(ott string) ([]provisioner.SignOption, error) + authorize func(ctx context.Context, ott string) ([]provisioner.SignOption, error) authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) getTLSOptions func() *authority.TLSOptions root func(shasum string) (*x509.Certificate, error) @@ -203,12 +214,8 @@ type mockAuthority struct { // TODO: remove once Authorize is deprecated. func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { - return m.AuthorizeSign(ott) -} - -func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { - if m.authorizeSign != nil { - return m.authorizeSign(ott) + if m.authorize != nil { + return m.authorize(ctx, ott) } return m.ret1.([]provisioner.SignOption), m.err } @@ -789,11 +796,10 @@ func Test_caHandler_Route(t *testing.T) { } } -func Test_caHandler_Health(t *testing.T) { +func Test_Health(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/health", nil) w := httptest.NewRecorder() - h := New(&mockAuthority{}).(*caHandler) - h.Health(w, req) + Health(w, req) res := w.Result() if res.StatusCode != 200 { @@ -811,7 +817,7 @@ func Test_caHandler_Health(t *testing.T) { } } -func Test_caHandler_Root(t *testing.T) { +func Test_Root(t *testing.T) { tests := []struct { name string root *x509.Certificate @@ -832,9 +838,9 @@ func Test_caHandler_Root(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler) + mockMustAuthority(t, &mockAuthority{ret1: tt.root, err: tt.err}) w := httptest.NewRecorder() - h.Root(w, req) + Root(w, req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -855,7 +861,7 @@ func Test_caHandler_Root(t *testing.T) { } } -func Test_caHandler_Sign(t *testing.T) { +func Test_Sign(t *testing.T) { csr := parseCertificateRequest(csrPEM) valid, err := json.Marshal(SignRequest{ CsrPEM: CertificateRequest{csr}, @@ -896,18 +902,18 @@ func Test_caHandler_Sign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.signErr, - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return tt.certAttrOpts, tt.autherr }, getTLSOptions: func() *authority.TLSOptions { return nil }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input)) w := httptest.NewRecorder() - h.Sign(logging.NewResponseLogger(w), req) + Sign(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -928,7 +934,7 @@ func Test_caHandler_Sign(t *testing.T) { } } -func Test_caHandler_Renew(t *testing.T) { +func Test_Renew(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } @@ -1018,7 +1024,7 @@ func Test_caHandler_Renew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) { jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root}) @@ -1039,12 +1045,12 @@ func Test_caHandler_Renew(t *testing.T) { getTLSOptions: func() *authority.TLSOptions { return nil }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/renew", nil) req.TLS = tt.tls req.Header = tt.header w := httptest.NewRecorder() - h.Renew(logging.NewResponseLogger(w), req) + Renew(logging.NewResponseLogger(w), req) res := w.Result() defer res.Body.Close() @@ -1073,7 +1079,7 @@ func Test_caHandler_Renew(t *testing.T) { } } -func Test_caHandler_Rekey(t *testing.T) { +func Test_Rekey(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } @@ -1104,16 +1110,16 @@ func Test_caHandler_Rekey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, getTLSOptions: func() *authority.TLSOptions { return nil }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input)) req.TLS = tt.tls w := httptest.NewRecorder() - h.Rekey(logging.NewResponseLogger(w), req) + Rekey(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -1134,7 +1140,7 @@ func Test_caHandler_Rekey(t *testing.T) { } } -func Test_caHandler_Provisioners(t *testing.T) { +func Test_Provisioners(t *testing.T) { type fields struct { Authority Authority } @@ -1200,10 +1206,8 @@ func Test_caHandler_Provisioners(t *testing.T) { assert.FatalError(t, err) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := &caHandler{ - Authority: tt.fields.Authority, - } - h.Provisioners(tt.args.w, tt.args.r) + mockMustAuthority(t, tt.fields.Authority) + Provisioners(tt.args.w, tt.args.r) rec := tt.args.w.(*httptest.ResponseRecorder) res := rec.Result() @@ -1238,7 +1242,7 @@ func Test_caHandler_Provisioners(t *testing.T) { } } -func Test_caHandler_ProvisionerKey(t *testing.T) { +func Test_ProvisionerKey(t *testing.T) { type fields struct { Authority Authority } @@ -1270,10 +1274,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := &caHandler{ - Authority: tt.fields.Authority, - } - h.ProvisionerKey(tt.args.w, tt.args.r) + mockMustAuthority(t, tt.fields.Authority) + ProvisionerKey(tt.args.w, tt.args.r) rec := tt.args.w.(*httptest.ResponseRecorder) res := rec.Result() @@ -1298,7 +1300,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { } } -func Test_caHandler_Roots(t *testing.T) { +func Test_Roots(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } @@ -1319,11 +1321,11 @@ func Test_caHandler_Roots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}) req := httptest.NewRequest("GET", "http://example.com/roots", nil) req.TLS = tt.tls w := httptest.NewRecorder() - h.Roots(w, req) + Roots(w, req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -1360,10 +1362,10 @@ func Test_caHandler_RootsPEM(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler) + mockMustAuthority(t, &mockAuthority{ret1: tt.roots, err: tt.err}) req := httptest.NewRequest("GET", "https://example.com/roots", nil) w := httptest.NewRecorder() - h.RootsPEM(w, req) + RootsPEM(w, req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -1384,7 +1386,7 @@ func Test_caHandler_RootsPEM(t *testing.T) { } } -func Test_caHandler_Federation(t *testing.T) { +func Test_Federation(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } @@ -1405,11 +1407,11 @@ func Test_caHandler_Federation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}) req := httptest.NewRequest("GET", "http://example.com/federation", nil) req.TLS = tt.tls w := httptest.NewRecorder() - h.Federation(w, req) + Federation(w, req) res := w.Result() if res.StatusCode != tt.statusCode { diff --git a/api/rekey.go b/api/rekey.go index 3116cf74..cda843a3 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -27,7 +27,7 @@ func (s *RekeyRequest) Validate() error { } // Rekey is similar to renew except that the certificate will be renewed with new key from csr. -func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { +func Rekey(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { render.Error(w, errs.BadRequest("missing client certificate")) return @@ -44,7 +44,8 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { return } - certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) + a := mustAuthority(r.Context()) + certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) if err != nil { render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) return @@ -60,6 +61,6 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), + TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } diff --git a/api/renew.go b/api/renew.go index 9c4bff32..6e9f680f 100644 --- a/api/renew.go +++ b/api/renew.go @@ -16,14 +16,15 @@ const ( // Renew uses the information of certificate in the TLS connection to create a // new one. -func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { - cert, err := h.getPeerCertificate(r) +func Renew(w http.ResponseWriter, r *http.Request) { + cert, err := getPeerCertificate(r) if err != nil { render.Error(w, err) return } - certChain, err := h.Authority.Renew(cert) + a := mustAuthority(r.Context()) + certChain, err := a.Renew(cert) if err != nil { render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return @@ -39,17 +40,18 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), + TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } -func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) { +func getPeerCertificate(r *http.Request) (*x509.Certificate, error) { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { return r.TLS.PeerCertificates[0], nil } if s := r.Header.Get(authorizationHeader); s != "" { if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { - return h.Authority.AuthorizeRenewToken(r.Context(), parts[1]) + ctx := r.Context() + return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1]) } } return nil, errs.BadRequest("missing client certificate") diff --git a/api/revoke.go b/api/revoke.go index c9da2c18..aebbb875 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -1,7 +1,6 @@ package api import ( - "context" "net/http" "golang.org/x/crypto/ocsp" @@ -49,7 +48,7 @@ func (r *RevokeRequest) Validate() (err error) { // NOTE: currently only Passive revocation is supported. // // TODO: Add CRL and OCSP support. -func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { +func Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -68,12 +67,14 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { PassiveOnly: body.Passive, } - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) + ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod) + a := mustAuthority(ctx) + // A token indicates that we are using the api via a provisioner token, // otherwise it is assumed that the certificate is revoking itself over mTLS. if len(body.OTT) > 0 { logOtt(w, body.OTT) - if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { + if _, err := a.Authorize(ctx, body.OTT); err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } @@ -98,7 +99,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { opts.MTLS = true } - if err := h.Authority.Revoke(ctx, opts); err != nil { + if err := a.Revoke(ctx, opts); err != nil { render.Error(w, errs.ForbiddenErr(err, "error revoking certificate")) return } diff --git a/api/revoke_test.go b/api/revoke_test.go index 7635ce68..c3fa6ceb 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -108,7 +108,7 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusOK, auth: &mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return nil, nil }, revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { @@ -152,7 +152,7 @@ func Test_caHandler_Revoke(t *testing.T) { statusCode: http.StatusOK, tls: cs, auth: &mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return nil, nil }, revoke: func(ctx context.Context, ri *authority.RevokeOptions) error { @@ -187,7 +187,7 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusInternalServerError, auth: &mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return nil, nil }, revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { @@ -209,7 +209,7 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusForbidden, auth: &mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return nil, nil }, revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { @@ -223,13 +223,13 @@ func Test_caHandler_Revoke(t *testing.T) { for name, _tc := range tests { tc := _tc(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*caHandler) + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input)) if tc.tls != nil { req.TLS = tc.tls } w := httptest.NewRecorder() - h.Revoke(logging.NewResponseLogger(w), req) + Revoke(logging.NewResponseLogger(w), req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) diff --git a/api/sign.go b/api/sign.go index b6bfcc8b..f7c3cc5a 100644 --- a/api/sign.go +++ b/api/sign.go @@ -49,7 +49,7 @@ type SignResponse struct { // Sign is an HTTP handler that reads a certificate request and an // one-time-token (ott) from the body and creates a new certificate with the // information in the certificate request. -func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { +func Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -68,13 +68,17 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { TemplateData: body.TemplateData, } - signOpts, err := h.Authority.AuthorizeSign(body.OTT) + ctx := r.Context() + a := mustAuthority(ctx) + + ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } - certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) + certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) return @@ -89,6 +93,6 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), + TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } diff --git a/api/ssh.go b/api/ssh.go index df96396f..4bd20495 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -250,7 +250,7 @@ type SSHBastionResponse struct { // SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. -func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { +func SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -289,13 +289,15 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) ctx = provisioner.NewContextWithToken(ctx, body.OTT) - signOpts, err := h.Authority.Authorize(ctx, body.OTT) + + a := mustAuthority(ctx) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } - cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) + cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return @@ -303,7 +305,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var addUserCertificate *SSHCertificate if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { - addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) + addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return @@ -316,7 +318,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if cr := body.IdentityCSR.CertificateRequest; cr != nil { ctx := authority.NewContextWithSkipTokenReuse(r.Context()) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) - signOpts, err := h.Authority.Authorize(ctx, body.OTT) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return @@ -328,7 +330,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { NotAfter: time.Unix(int64(cert.ValidBefore), 0), }) - certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) + certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) return @@ -345,8 +347,9 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { // SSHRoots is an HTTP handler that returns the SSH public keys for user and host // certificates. -func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { - keys, err := h.Authority.GetSSHRoots(r.Context()) +func SSHRoots(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + keys, err := mustAuthority(ctx).GetSSHRoots(ctx) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -370,8 +373,9 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { // SSHFederation is an HTTP handler that returns the federated SSH public keys // for user and host certificates. -func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { - keys, err := h.Authority.GetSSHFederation(r.Context()) +func SSHFederation(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + keys, err := mustAuthority(ctx).GetSSHFederation(ctx) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -395,7 +399,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { // SSHConfig is an HTTP handler that returns rendered templates for ssh clients // and servers. -func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { +func SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -406,7 +410,8 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { return } - ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data) + ctx := r.Context() + ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -427,7 +432,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { } // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. -func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { +func SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -438,7 +443,8 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { return } - exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) + ctx := r.Context() + exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -449,13 +455,14 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { } // SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts. -func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { +func SSHGetHosts(w http.ResponseWriter, r *http.Request) { var cert *x509.Certificate if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { cert = r.TLS.PeerCertificates[0] } - hosts, err := h.Authority.GetSSHHosts(r.Context(), cert) + ctx := r.Context() + hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -466,7 +473,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { } // SSHBastion provides returns the bastion configured if any. -func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { +func SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -477,7 +484,8 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { return } - bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname) + ctx := r.Context() + bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname) if err != nil { render.Error(w, errs.InternalServerErr(err)) return diff --git a/api/sshRekey.go b/api/sshRekey.go index 1819428a..6c0a5064 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -39,7 +39,7 @@ type SSHRekeyResponse struct { // SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. -func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { +func SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -60,7 +60,9 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) ctx = provisioner.NewContextWithToken(ctx, body.OTT) - signOpts, err := h.Authority.Authorize(ctx, body.OTT) + + a := mustAuthority(ctx) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return @@ -71,7 +73,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { return } - newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) + newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) return @@ -81,7 +83,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { notBefore := time.Unix(int64(oldCert.ValidAfter), 0) notAfter := time.Unix(int64(oldCert.ValidBefore), 0) - identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) + identity, err := renewIdentityCertificate(r, notBefore, notAfter) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return diff --git a/api/sshRenew.go b/api/sshRenew.go index 58f2e525..4e4d0b04 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -37,7 +37,7 @@ type SSHRenewResponse struct { // SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. -func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { +func SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -52,7 +52,9 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) ctx = provisioner.NewContextWithToken(ctx, body.OTT) - _, err := h.Authority.Authorize(ctx, body.OTT) + + a := mustAuthority(ctx) + _, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return @@ -63,7 +65,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { return } - newCert, err := h.Authority.RenewSSH(ctx, oldCert) + newCert, err := a.RenewSSH(ctx, oldCert) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) return @@ -73,7 +75,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { notBefore := time.Unix(int64(oldCert.ValidAfter), 0) notAfter := time.Unix(int64(oldCert.ValidBefore), 0) - identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) + identity, err := renewIdentityCertificate(r, notBefore, notAfter) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return @@ -86,7 +88,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { } // renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the -func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) { +func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { return nil, nil } @@ -106,7 +108,7 @@ func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfte cert.NotAfter = notAfter } - certChain, err := h.Authority.Renew(cert) + certChain, err := mustAuthority(r.Context()).Renew(cert) if err != nil { return nil, err } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index a33082cd..d377def9 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) { // Revoke supports handful of different methods that revoke a Certificate. // // NOTE: currently only Passive revocation is supported. -func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { +func SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -68,16 +68,19 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod) + a := mustAuthority(ctx) + // A token indicates that we are using the api via a provisioner token, // otherwise it is assumed that the certificate is revoking itself over mTLS. logOtt(w, body.OTT) - if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { + + if _, err := a.Authorize(ctx, body.OTT); err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT - if err := h.Authority.Revoke(ctx, opts); err != nil { + if err := a.Revoke(ctx, opts); err != nil { render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) return } diff --git a/api/ssh_test.go b/api/ssh_test.go index 88a301f5..57dd6775 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -251,7 +251,7 @@ func TestSignSSHRequest_Validate(t *testing.T) { } } -func Test_caHandler_SSHSign(t *testing.T) { +func Test_SSHSign(t *testing.T) { user, err := getSignedUserCertificate() assert.FatalError(t, err) host, err := getSignedHostCertificate() @@ -315,8 +315,8 @@ func Test_caHandler_SSHSign(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ - authorizeSign: func(ott string) ([]provisioner.SignOption, error) { + mockMustAuthority(t, &mockAuthority{ + authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return []provisioner.SignOption{}, tt.authErr }, signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { @@ -328,11 +328,11 @@ func Test_caHandler_SSHSign(t *testing.T) { sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { return tt.tlsSignCerts, tt.tlsSignErr }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req)) w := httptest.NewRecorder() - h.SSHSign(logging.NewResponseLogger(w), req) + SSHSign(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -353,7 +353,7 @@ func Test_caHandler_SSHSign(t *testing.T) { } } -func Test_caHandler_SSHRoots(t *testing.T) { +func Test_SSHRoots(t *testing.T) { user, err := ssh.NewPublicKey(sshUserKey.Public()) assert.FatalError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) @@ -378,15 +378,15 @@ func Test_caHandler_SSHRoots(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody) w := httptest.NewRecorder() - h.SSHRoots(logging.NewResponseLogger(w), req) + SSHRoots(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -407,7 +407,7 @@ func Test_caHandler_SSHRoots(t *testing.T) { } } -func Test_caHandler_SSHFederation(t *testing.T) { +func Test_SSHFederation(t *testing.T) { user, err := ssh.NewPublicKey(sshUserKey.Public()) assert.FatalError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) @@ -432,15 +432,15 @@ func Test_caHandler_SSHFederation(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody) w := httptest.NewRecorder() - h.SSHFederation(logging.NewResponseLogger(w), req) + SSHFederation(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -461,7 +461,7 @@ func Test_caHandler_SSHFederation(t *testing.T) { } } -func Test_caHandler_SSHConfig(t *testing.T) { +func Test_SSHConfig(t *testing.T) { userOutput := []templates.Output{ {Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")}, {Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")}, @@ -492,15 +492,15 @@ func Test_caHandler_SSHConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { return tt.output, tt.err }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req)) w := httptest.NewRecorder() - h.SSHConfig(logging.NewResponseLogger(w), req) + SSHConfig(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -521,7 +521,7 @@ func Test_caHandler_SSHConfig(t *testing.T) { } } -func Test_caHandler_SSHCheckHost(t *testing.T) { +func Test_SSHCheckHost(t *testing.T) { tests := []struct { name string req string @@ -539,15 +539,15 @@ func Test_caHandler_SSHCheckHost(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) { return tt.exists, tt.err }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req)) w := httptest.NewRecorder() - h.SSHCheckHost(logging.NewResponseLogger(w), req) + SSHCheckHost(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -568,7 +568,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) { } } -func Test_caHandler_SSHGetHosts(t *testing.T) { +func Test_SSHGetHosts(t *testing.T) { hosts := []authority.Host{ {HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"}, {HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"}, @@ -590,15 +590,15 @@ func Test_caHandler_SSHGetHosts(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) { return tt.hosts, tt.err }, - }).(*caHandler) + }) req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody) w := httptest.NewRecorder() - h.SSHGetHosts(logging.NewResponseLogger(w), req) + SSHGetHosts(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -619,7 +619,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) { } } -func Test_caHandler_SSHBastion(t *testing.T) { +func Test_SSHBastion(t *testing.T) { bastion := &authority.Bastion{ Hostname: "bastion.local", } @@ -645,15 +645,15 @@ func Test_caHandler_SSHBastion(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + mockMustAuthority(t, &mockAuthority{ getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) { return tt.bastion, tt.bastionErr }, - }).(*caHandler) + }) req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req)) w := httptest.NewRecorder() - h.SSHBastion(logging.NewResponseLogger(w), req) + SSHBastion(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index da491dfe..db393e9a 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -33,7 +33,7 @@ type GetExternalAccountKeysResponse struct { // requireEABEnabled is a middleware that ensures ACME EAB is enabled // before serving requests that act on ACME EAB credentials. -func (h *Handler) requireEABEnabled(next http.HandlerFunc) http.HandlerFunc { +func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) @@ -53,32 +53,33 @@ func (h *Handler) requireEABEnabled(next http.HandlerFunc) http.HandlerFunc { } } -type acmeAdminResponderInterface interface { +// ACMEAdminResponder is responsible for writing ACME admin responses +type ACMEAdminResponder interface { GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) } -// ACMEAdminResponder is responsible for writing ACME admin responses -type ACMEAdminResponder struct{} +// acmeAdminResponder implements ACMEAdminResponder. +type acmeAdminResponder struct{} // NewACMEAdminResponder returns a new ACMEAdminResponder -func NewACMEAdminResponder() *ACMEAdminResponder { - return &ACMEAdminResponder{} +func NewACMEAdminResponder() ACMEAdminResponder { + return &acmeAdminResponder{} } // GetExternalAccountKeys writes the response for the EAB keys GET endpoint -func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { +func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // CreateExternalAccountKey writes the response for the EAB key POST endpoint -func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { +func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint -func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { +func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } diff --git a/authority/admin/api/acme_test.go b/authority/admin/api/acme_test.go index 3ff32763..6d478145 100644 --- a/authority/admin/api/acme_test.go +++ b/authority/admin/api/acme_test.go @@ -33,6 +33,17 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error { return protojson.Unmarshal(data, m) } +func mockMustAuthority(t *testing.T, a adminAuthority) { + t.Helper() + fn := mustAuthority + t.Cleanup(func() { + mustAuthority = fn + }) + mustAuthority = func(ctx context.Context) adminAuthority { + return a + } +} + func TestHandler_requireEABEnabled(t *testing.T) { type test struct { ctx context.Context @@ -117,12 +128,9 @@ func TestHandler_requireEABEnabled(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{} - - req := httptest.NewRequest("GET", "/foo", nil) - req = req.WithContext(tc.ctx) + req := httptest.NewRequest("GET", "/foo", nil).WithContext(tc.ctx) w := httptest.NewRecorder() - h.requireEABEnabled(tc.next)(w, req) + requireEABEnabled(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index a033b1a5..c7adced3 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -85,10 +85,10 @@ type DeleteResponse struct { } // GetAdmin returns the requested admin, or an error. -func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { +func GetAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") - adm, ok := h.auth.LoadAdminByID(id) + adm, ok := mustAuthority(r.Context()).LoadAdminByID(id) if !ok { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)) @@ -98,7 +98,7 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { } // GetAdmins returns a segment of admins associated with the authority. -func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { +func GetAdmins(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, @@ -106,7 +106,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { return } - admins, nextCursor, err := h.auth.GetAdmins(cursor, limit) + admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) return @@ -118,7 +118,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { } // CreateAdmin creates a new admin. -func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { +func CreateAdmin(w http.ResponseWriter, r *http.Request) { var body CreateAdminRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) @@ -130,7 +130,8 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { return } - p, err := h.auth.LoadProvisionerByName(body.Provisioner) + auth := mustAuthority(r.Context()) + p, err := auth.LoadProvisionerByName(body.Provisioner) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) return @@ -141,7 +142,7 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { Type: body.Type, } // Store to authority collection. - if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil { + if err := auth.StoreAdmin(r.Context(), adm, p); err != nil { render.Error(w, admin.WrapErrorISE(err, "error storing admin")) return } @@ -150,10 +151,10 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { } // DeleteAdmin deletes admin. -func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { +func DeleteAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") - if err := h.auth.RemoveAdmin(r.Context(), id); err != nil { + if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil { render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) return } @@ -162,7 +163,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { } // UpdateAdmin updates an existing admin. -func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { +func UpdateAdmin(w http.ResponseWriter, r *http.Request) { var body UpdateAdminRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) @@ -175,8 +176,8 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { } id := chi.URLParam(r, "id") - - adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) + auth := mustAuthority(r.Context()) + adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id)) return diff --git a/authority/admin/api/admin_test.go b/authority/admin/api/admin_test.go index cc77ef77..ecb95244 100644 --- a/authority/admin/api/admin_test.go +++ b/authority/admin/api/admin_test.go @@ -352,14 +352,11 @@ func TestHandler_GetAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } - + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetAdmin(w, req) + GetAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -491,13 +488,10 @@ func TestHandler_GetAdmins(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } - + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetAdmins(w, req) + GetAdmins(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -675,13 +669,11 @@ func TestHandler_CreateAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.CreateAdmin(w, req) + CreateAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -767,13 +759,11 @@ func TestHandler_DeleteAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.DeleteAdmin(w, req) + DeleteAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -912,13 +902,11 @@ func TestHandler_UpdateAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.UpdateAdmin(w, req) + UpdateAdmin(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index eb52ad58..1e5919ce 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -1,50 +1,58 @@ package api import ( + "context" "net/http" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" ) // Handler is the Admin API request handler. type Handler struct { - adminDB admin.DB - auth adminAuthority - acmeDB acme.DB - acmeResponder acmeAdminResponderInterface - policyResponder policyAdminResponderInterface + acmeResponder ACMEAdminResponder + policyResponder PolicyAdminResponder +} + +// Route traffic and implement the Router interface. +// +// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) +func (h *Handler) Route(r api.Router) { + Route(r, h.acmeResponder, h.policyResponder) } // NewHandler returns a new Authority Config Handler. -func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface, policyResponder policyAdminResponderInterface) api.RouterHandler { +// +// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) +func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) api.RouterHandler { return &Handler{ - auth: auth, - adminDB: adminDB, - acmeDB: acmeDB, acmeResponder: acmeResponder, policyResponder: policyResponder, } } -// Route traffic and implement the Router interface. -func (h *Handler) Route(r api.Router) { +var mustAuthority = func(ctx context.Context) adminAuthority { + return authority.MustFromContext(ctx) +} +// Route traffic and implement the Router interface. +func Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) { authnz := func(next http.HandlerFunc) http.HandlerFunc { - return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) + return extractAuthorizeTokenAdmin(requireAPIEnabled(next)) } enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc { - return h.checkAction(next, true) + return checkAction(next, true) } disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc { - return h.checkAction(next, false) + return checkAction(next, false) } acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc { - return authnz(h.loadProvisionerByName(h.requireEABEnabled(next))) + return authnz(loadProvisionerByName(requireEABEnabled(next))) } authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { @@ -52,53 +60,58 @@ func (h *Handler) Route(r api.Router) { } provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { - return authnz(disabledInStandalone(h.loadProvisionerByName(next))) + return authnz(disabledInStandalone(loadProvisionerByName(next))) } acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { - return authnz(disabledInStandalone(h.loadProvisionerByName(h.requireEABEnabled(h.loadExternalAccountKey(next))))) + return authnz(disabledInStandalone(loadProvisionerByName(requireEABEnabled(loadExternalAccountKey(next))))) } // Provisioners - r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner)) - r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners)) - r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner)) - r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner)) - r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner)) + r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner)) + r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners)) + r.MethodFunc("POST", "/provisioners", authnz(CreateProvisioner)) + r.MethodFunc("PUT", "/provisioners/{name}", authnz(UpdateProvisioner)) + r.MethodFunc("DELETE", "/provisioners/{name}", authnz(DeleteProvisioner)) // Admins - r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin)) - r.MethodFunc("GET", "/admins", authnz(h.GetAdmins)) - r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin)) - r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin)) - r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin)) - - // ACME External Account Binding Keys - r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(h.acmeResponder.GetExternalAccountKeys)) - r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(h.acmeResponder.GetExternalAccountKeys)) - r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(h.acmeResponder.CreateExternalAccountKey)) - r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(h.acmeResponder.DeleteExternalAccountKey)) - - // Policy - Authority - r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(h.policyResponder.GetAuthorityPolicy)) - r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(h.policyResponder.CreateAuthorityPolicy)) - r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(h.policyResponder.UpdateAuthorityPolicy)) - r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(h.policyResponder.DeleteAuthorityPolicy)) - - // Policy - Provisioner - r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.GetProvisionerPolicy)) - r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.CreateProvisionerPolicy)) - r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.UpdateProvisionerPolicy)) - r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.DeleteProvisionerPolicy)) - - // Policy - ACME Account - r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.GetACMEAccountPolicy)) - r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.GetACMEAccountPolicy)) - r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.CreateACMEAccountPolicy)) - r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.CreateACMEAccountPolicy)) - r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.UpdateACMEAccountPolicy)) - r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.UpdateACMEAccountPolicy)) - r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.DeleteACMEAccountPolicy)) - r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.DeleteACMEAccountPolicy)) + r.MethodFunc("GET", "/admins/{id}", authnz(GetAdmin)) + r.MethodFunc("GET", "/admins", authnz(GetAdmins)) + r.MethodFunc("POST", "/admins", authnz(CreateAdmin)) + r.MethodFunc("PATCH", "/admins/{id}", authnz(UpdateAdmin)) + r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin)) + + // ACME responder + if acmeResponder != nil { + // ACME External Account Binding Keys + r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys)) + r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys)) + r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.CreateExternalAccountKey)) + r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(acmeResponder.DeleteExternalAccountKey)) + } + // Policy responder + if policyResponder != nil { + // Policy - Authority + r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(policyResponder.GetAuthorityPolicy)) + r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(policyResponder.CreateAuthorityPolicy)) + r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(policyResponder.UpdateAuthorityPolicy)) + r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(policyResponder.DeleteAuthorityPolicy)) + + // Policy - Provisioner + r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.GetProvisionerPolicy)) + r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.CreateProvisionerPolicy)) + r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.UpdateProvisionerPolicy)) + r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.DeleteProvisionerPolicy)) + + // Policy - ACME Account + r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy)) + r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy)) + r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy)) + r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy)) + r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy)) + r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy)) + r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy)) + r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy)) + } } diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index 24adfdf2..780cfb65 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -17,11 +17,10 @@ import ( // requireAPIEnabled is a middleware that ensures the Administration API // is enabled before servicing requests. -func (h *Handler) requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc { +func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if !h.auth.IsAdminAPIEnabled() { - render.Error(w, admin.NewError(admin.ErrorNotImplementedType, - "administration API not enabled")) + if !mustAuthority(r.Context()).IsAdminAPIEnabled() { + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled")) return } next(w, r) @@ -29,7 +28,7 @@ func (h *Handler) requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc { } // extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token. -func (h *Handler) extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc { +func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { tok := r.Header.Get("Authorization") @@ -39,36 +38,39 @@ func (h *Handler) extractAuthorizeTokenAdmin(next http.HandlerFunc) http.Handler return } - adm, err := h.auth.AuthorizeAdminToken(r, tok) + ctx := r.Context() + adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok) if err != nil { render.Error(w, err) return } - ctx := linkedca.NewContextWithAdmin(r.Context(), adm) + ctx = linkedca.NewContextWithAdmin(ctx, adm) next(w, r.WithContext(ctx)) } } // loadProvisionerByName is a middleware that searches for a provisioner // by name and stores it in the context. -func (h *Handler) loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc { +func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - - ctx := r.Context() - name := chi.URLParam(r, "provisionerName") var ( p provisioner.Interface err error ) + ctx := r.Context() + auth := mustAuthority(ctx) + adminDB := admin.MustFromContext(ctx) + name := chi.URLParam(r, "provisionerName") + // TODO(hs): distinguish 404 vs. 500 - if p, err = h.auth.LoadProvisionerByName(name); err != nil { + if p, err = auth.LoadProvisionerByName(name); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } - prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) + prov, err := adminDB.GetProvisioner(ctx, p.GetID()) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name)) return @@ -80,9 +82,8 @@ func (h *Handler) loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc } // checkAction checks if an action is supported in standalone or not -func (h *Handler) checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc { +func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - // actions allowed in standalone mode are always supported if supportedInStandalone { next(w, r) @@ -91,7 +92,7 @@ func (h *Handler) checkAction(next http.HandlerFunc, supportedInStandalone bool) // when an action is not supported in standalone mode and when // using a nosql.DB backend, actions are not supported - if _, ok := h.adminDB.(*nosql.DB); ok { + if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok { render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "operation not supported in standalone mode")) return @@ -104,10 +105,11 @@ func (h *Handler) checkAction(next http.HandlerFunc, supportedInStandalone bool) // loadExternalAccountKey is a middleware that searches for an ACME // External Account Key by reference or keyID and stores it in the context. -func (h *Handler) loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc { +func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) + acmeDB := acme.MustDatabaseFromContext(ctx) reference := chi.URLParam(r, "reference") keyID := chi.URLParam(r, "keyID") @@ -118,9 +120,9 @@ func (h *Handler) loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc ) if keyID != "" { - eak, err = h.acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID) + eak, err = acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID) } else { - eak, err = h.acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference) + eak, err = acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference) } if err != nil { diff --git a/authority/admin/api/middleware_test.go b/authority/admin/api/middleware_test.go index 42caed9a..4684b047 100644 --- a/authority/admin/api/middleware_test.go +++ b/authority/admin/api/middleware_test.go @@ -71,13 +71,11 @@ func TestHandler_requireAPIEnabled(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.requireAPIEnabled(tc.next)(w, req) + requireAPIEnabled(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -196,13 +194,10 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } - + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.extractAuthorizeTokenAdmin(tc.next)(w, req) + extractAuthorizeTokenAdmin(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -251,6 +246,7 @@ func TestHandler_loadProvisionerByName(t *testing.T) { return test{ ctx: ctx, auth: auth, + adminDB: &admin.MockDB{}, statusCode: 500, err: err, } @@ -326,16 +322,13 @@ func TestHandler_loadProvisionerByName(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - } - + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.loadProvisionerByName(tc.next)(w, req) + loadProvisionerByName(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -405,14 +398,10 @@ func TestHandler_checkAction(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - - adminDB: tc.adminDB, - } - - req := httptest.NewRequest("GET", "/foo", nil) + ctx := admin.NewContext(context.Background(), tc.adminDB) + req := httptest.NewRequest("GET", "/foo", nil).WithContext(ctx) w := httptest.NewRecorder() - h.checkAction(tc.next, tc.supportedInStandalone)(w, req) + checkAction(tc.next, tc.supportedInStandalone)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -653,14 +642,11 @@ func TestHandler_loadExternalAccountKey(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - acmeDB: tc.acmeDB, - } - + ctx := acme.NewDatabaseContext(tc.ctx, tc.acmeDB) req := httptest.NewRequest("GET", "/foo", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.loadExternalAccountKey(tc.next)(w, req) + loadExternalAccountKey(tc.next)(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) diff --git a/authority/admin/api/policy.go b/authority/admin/api/policy.go index 6af1104a..a478c83c 100644 --- a/authority/admin/api/policy.go +++ b/authority/admin/api/policy.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "net/http" @@ -14,7 +15,9 @@ import ( "github.com/smallstep/certificates/authority/policy" ) -type policyAdminResponderInterface interface { +// PolicyAdminResponder is the interface responsible for writing ACME admin +// responses. +type PolicyAdminResponder interface { GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) @@ -29,39 +32,24 @@ type policyAdminResponderInterface interface { DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) } -// PolicyAdminResponder is responsible for writing ACME admin responses -type PolicyAdminResponder struct { - auth adminAuthority - adminDB admin.DB - acmeDB acme.DB - isLinkedCA bool -} - -// NewACMEAdminResponder returns a new ACMEAdminResponder -func NewPolicyAdminResponder(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB) *PolicyAdminResponder { - - var isLinkedCA bool - if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok { - isLinkedCA = a.IsLinkedCA() - } +// policyAdminResponder implements PolicyAdminResponder. +type policyAdminResponder struct{} - return &PolicyAdminResponder{ - auth: auth, - adminDB: adminDB, - acmeDB: acmeDB, - isLinkedCA: isLinkedCA, - } +// NewACMEAdminResponder returns a new PolicyAdminResponder. +func NewPolicyAdminResponder() PolicyAdminResponder { + return &policyAdminResponder{} } // GetAuthorityPolicy handles the GET /admin/authority/policy request -func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - authorityPolicy, err := par.auth.GetAuthorityPolicy(r.Context()) + auth := mustAuthority(ctx) + authorityPolicy, err := auth.GetAuthorityPolicy(r.Context()) if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) return @@ -76,15 +64,15 @@ func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *ht } // CreateAuthorityPolicy handles the POST /admin/authority/policy request -func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() - authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx) + auth := mustAuthority(ctx) + authorityPolicy, err := auth.GetAuthorityPolicy(ctx) if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) @@ -113,7 +101,7 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r adm := linkedca.MustAdminFromContext(ctx) var createdPolicy *linkedca.Policy - if createdPolicy, err = par.auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil { + if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil { if isBadRequest(err) { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy")) return @@ -127,15 +115,15 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r } // UpdateAuthorityPolicy handles the PUT /admin/authority/policy request -func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() - authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx) + auth := mustAuthority(ctx) + authorityPolicy, err := auth.GetAuthorityPolicy(ctx) if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) @@ -163,7 +151,7 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r adm := linkedca.MustAdminFromContext(ctx) var updatedPolicy *linkedca.Policy - if updatedPolicy, err = par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { + if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { if isBadRequest(err) { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy")) return @@ -177,15 +165,15 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r } // DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request -func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() - authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx) + auth := mustAuthority(ctx) + authorityPolicy, err := auth.GetAuthorityPolicy(ctx) if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) @@ -197,7 +185,7 @@ func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r return } - if err := par.auth.RemoveAuthorityPolicy(ctx); err != nil { + if err := auth.RemoveAuthorityPolicy(ctx); err != nil { render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy")) return } @@ -206,15 +194,14 @@ func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r } // GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request -func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - prov := linkedca.MustProvisionerFromContext(r.Context()) - + prov := linkedca.MustProvisionerFromContext(ctx) provisionerPolicy := prov.GetPolicy() if provisionerPolicy == nil { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) @@ -225,16 +212,14 @@ func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r * } // CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request -func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) - provisionerPolicy := prov.GetPolicy() if provisionerPolicy != nil { adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name) @@ -256,8 +241,8 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, } prov.Policy = newPolicy - - if err := par.auth.UpdateProvisioner(ctx, prov); err != nil { + auth := mustAuthority(ctx) + if err := auth.UpdateProvisioner(ctx, prov); err != nil { if isBadRequest(err) { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy")) return @@ -271,16 +256,14 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, } // UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request -func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) - provisionerPolicy := prov.GetPolicy() if provisionerPolicy == nil { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) @@ -301,7 +284,8 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, } prov.Policy = newPolicy - if err := par.auth.UpdateProvisioner(ctx, prov); err != nil { + auth := mustAuthority(ctx) + if err := auth.UpdateProvisioner(ctx, prov); err != nil { if isBadRequest(err) { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy")) return @@ -315,16 +299,14 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, } // DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request -func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) - if prov.Policy == nil { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) return @@ -333,7 +315,8 @@ func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, // remove the policy prov.Policy = nil - if err := par.auth.UpdateProvisioner(ctx, prov); err != nil { + auth := mustAuthority(ctx) + if err := auth.UpdateProvisioner(ctx, prov); err != nil { render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy")) return } @@ -341,16 +324,14 @@ func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) } -func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() eak := linkedca.MustExternalAccountKeyFromContext(ctx) - eakPolicy := eak.GetPolicy() if eakPolicy == nil { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) @@ -360,17 +341,15 @@ func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r * render.ProtoJSONStatus(w, eakPolicy, http.StatusOK) } -func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) eak := linkedca.MustExternalAccountKeyFromContext(ctx) - eakPolicy := eak.GetPolicy() if eakPolicy != nil { adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id) @@ -394,7 +373,8 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, eak.Policy = newPolicy acmeEAK := linkedEAKToCertificates(eak) - if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { + acmeDB := acme.MustDatabaseFromContext(ctx) + if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy")) return } @@ -402,17 +382,15 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, render.ProtoJSONStatus(w, newPolicy, http.StatusCreated) } -func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) eak := linkedca.MustExternalAccountKeyFromContext(ctx) - eakPolicy := eak.GetPolicy() if eakPolicy == nil { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) @@ -434,7 +412,8 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, eak.Policy = newPolicy acmeEAK := linkedEAKToCertificates(eak) - if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { + acmeDB := acme.MustDatabaseFromContext(ctx) + if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy")) return } @@ -442,17 +421,15 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, render.ProtoJSONStatus(w, newPolicy, http.StatusOK) } -func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { - - if err := par.blockLinkedCA(); err != nil { +func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if err := blockLinkedCA(ctx); err != nil { render.Error(w, err) return } - ctx := r.Context() prov := linkedca.MustProvisionerFromContext(ctx) eak := linkedca.MustExternalAccountKeyFromContext(ctx) - eakPolicy := eak.GetPolicy() if eakPolicy == nil { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) @@ -463,7 +440,8 @@ func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, eak.Policy = nil acmeEAK := linkedEAKToCertificates(eak) - if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { + acmeDB := acme.MustDatabaseFromContext(ctx) + if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy")) return } @@ -472,9 +450,10 @@ func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, } // blockLinkedCA blocks all API operations on linked deployments -func (par *PolicyAdminResponder) blockLinkedCA() error { +func blockLinkedCA(ctx context.Context) error { // temporary blocking linked deployments - if par.isLinkedCA { + adminDB := admin.MustFromContext(ctx) + if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok && a.IsLinkedCA() { return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") } return nil diff --git a/authority/admin/api/policy_test.go b/authority/admin/api/policy_test.go index 1e70db52..1ec88fb6 100644 --- a/authority/admin/api/policy_test.go +++ b/authority/admin/api/policy_test.go @@ -109,7 +109,8 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") err.Message = "error retrieving authority policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorServerInternalType, "force") @@ -124,7 +125,8 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist") err.Message = "authority policy does not exist" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorNotFoundType, "not found") @@ -179,7 +181,8 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { }, } return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return policy, nil @@ -234,11 +237,12 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("GET", "/foo", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.GetAuthorityPolicy(w, req) @@ -301,7 +305,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") err.Message = "error retrieving authority policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorServerInternalType, "force") @@ -316,7 +321,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { err := admin.NewError(admin.ErrorConflictType, "authority already has a policy") err.Message = "authority already has a policy" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return &linkedca.Policy{}, nil @@ -332,7 +338,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" body := []byte("{?}") return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorNotFoundType, "not found") @@ -358,7 +365,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { } }`) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorNotFoundType, "not found") @@ -509,11 +517,13 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.CreateAuthorityPolicy(w, req) @@ -586,7 +596,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") err.Message = "error retrieving authority policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorServerInternalType, "force") @@ -602,7 +613,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { err.Message = "authority policy does not exist" err.Status = http.StatusNotFound return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, nil @@ -625,7 +637,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" body := []byte("{?}") return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return policy, nil @@ -658,7 +671,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { } }`) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return policy, nil @@ -809,11 +823,13 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.UpdateAuthorityPolicy(w, req) @@ -886,7 +902,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) { err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") err.Message = "error retrieving authority policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorServerInternalType, "force") @@ -902,7 +919,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) { err.Message = "authority policy does not exist" err.Status = http.StatusNotFound return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, nil @@ -924,7 +942,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) { err := admin.NewErrorISE("error deleting authority policy: force") err.Message = "error deleting authority policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return policy, nil @@ -947,7 +966,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) { } ctx := context.Background() return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return policy, nil @@ -963,11 +983,13 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.DeleteAuthorityPolicy(w, req) @@ -1033,6 +1055,7 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) { err.Message = "provisioner policy does not exist" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 404, } @@ -1085,7 +1108,8 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) { } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, response: &testPolicyResponse{ X509: &testX509Policy{ Allow: &testX509Names{ @@ -1135,11 +1159,13 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("GET", "/foo", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.GetProvisionerPolicy(w, req) @@ -1214,6 +1240,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { err.Message = "provisioner provName already has a policy" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 409, } @@ -1228,6 +1255,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { body := []byte("{?}") return test{ ctx: ctx, + adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, @@ -1251,7 +1279,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { } }`) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorNotFoundType, "not found") @@ -1283,7 +1312,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ @@ -1318,7 +1348,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ @@ -1351,7 +1382,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return nil @@ -1372,11 +1404,12 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.CreateProvisionerPolicy(w, req) @@ -1452,6 +1485,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { err.Message = "provisioner policy does not exist" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 404, } @@ -1474,6 +1508,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { body := []byte("{?}") return test{ ctx: ctx, + adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, @@ -1505,7 +1540,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { } }`) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { return nil, admin.NewError(admin.ErrorNotFoundType, "not found") @@ -1538,7 +1574,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ @@ -1574,7 +1611,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return &authority.PolicyError{ @@ -1608,7 +1646,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return nil @@ -1629,11 +1668,12 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.UpdateProvisionerPolicy(w, req) @@ -1710,6 +1750,7 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) { err.Message = "provisioner policy does not exist" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 404, } @@ -1723,7 +1764,8 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) { err := admin.NewErrorISE("error deleting provisioner policy: force") err.Message = "error deleting provisioner policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return errors.New("force") @@ -1740,7 +1782,8 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) { } ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, auth: &mockAdminAuthority{ MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { return nil @@ -1753,11 +1796,13 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.DeleteProvisionerPolicy(w, req) @@ -1828,6 +1873,7 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) { err.Message = "ACME EAK policy does not exist" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 404, } @@ -1885,7 +1931,8 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) { ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, response: &testPolicyResponse{ X509: &testX509Policy{ Allow: &testX509Names{ @@ -1935,11 +1982,12 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("GET", "/foo", nil) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.GetACMEAccountPolicy(w, req) @@ -2018,6 +2066,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { err.Message = "ACME EAK eakID already has a policy" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 409, } @@ -2036,6 +2085,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { body := []byte("{?}") return test{ ctx: ctx, + adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, @@ -2064,6 +2114,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { }`) return test{ ctx: ctx, + adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, @@ -2091,7 +2142,8 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) @@ -2124,7 +2176,8 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) @@ -2147,11 +2200,12 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.CreateACMEAccountPolicy(w, req) @@ -2231,6 +2285,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { err.Message = "ACME EAK policy does not exist" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 404, } @@ -2257,6 +2312,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { body := []byte("{?}") return test{ ctx: ctx, + adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, @@ -2293,6 +2349,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { }`) return test{ ctx: ctx, + adminDB: &admin.MockDB{}, body: body, err: adminErr, statusCode: 400, @@ -2321,7 +2378,8 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) @@ -2355,7 +2413,8 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { body, err := protojson.Marshal(policy) assert.NoError(t, err) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) @@ -2378,11 +2437,12 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.UpdateACMEAccountPolicy(w, req) @@ -2462,6 +2522,7 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { err.Message = "ACME EAK policy does not exist" return test{ ctx: ctx, + adminDB: &admin.MockDB{}, err: err, statusCode: 404, } @@ -2487,7 +2548,8 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { err := admin.NewErrorISE("error deleting ACME EAK policy: force") err.Message = "error deleting ACME EAK policy: force" return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) @@ -2518,7 +2580,8 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) return test{ - ctx: ctx, + ctx: ctx, + adminDB: &admin.MockDB{}, acmeDB: &acme.MockDB{ MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { assert.Equal(t, "provID", provisionerID) @@ -2533,11 +2596,12 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - - par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + ctx = acme.NewDatabaseContext(ctx, tc.acmeDB) + par := NewPolicyAdminResponder() req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() par.DeleteACMEAccountPolicy(w, req) diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index 1cad62dd..149f2c6a 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -23,29 +23,31 @@ type GetProvisionersResponse struct { } // GetProvisioner returns the requested provisioner, or an error. -func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - id := r.URL.Query().Get("id") - name := chi.URLParam(r, "name") - +func GetProvisioner(w http.ResponseWriter, r *http.Request) { var ( p provisioner.Interface err error ) + + ctx := r.Context() + id := r.URL.Query().Get("id") + name := chi.URLParam(r, "name") + auth := mustAuthority(ctx) + db := admin.MustFromContext(ctx) + if len(id) > 0 { - if p, err = h.auth.LoadProvisionerByID(id); err != nil { + if p, err = auth.LoadProvisionerByID(id); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { - if p, err = h.auth.LoadProvisionerByName(name); err != nil { + if p, err = auth.LoadProvisionerByName(name); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } - prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) + prov, err := db.GetProvisioner(ctx, p.GetID()) if err != nil { render.Error(w, err) return @@ -54,7 +56,7 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { } // GetProvisioners returns the given segment of provisioners associated with the authority. -func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { +func GetProvisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, @@ -62,7 +64,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { return } - p, next, err := h.auth.GetProvisioners(cursor, limit) + p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -74,7 +76,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { } // CreateProvisioner creates a new prov. -func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { +func CreateProvisioner(w http.ResponseWriter, r *http.Request) { var prov = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, prov); err != nil { render.Error(w, err) @@ -87,7 +89,7 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { return } - if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil { + if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil { render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) return } @@ -95,27 +97,29 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { } // DeleteProvisioner deletes a provisioner. -func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { - id := r.URL.Query().Get("id") - name := chi.URLParam(r, "name") - +func DeleteProvisioner(w http.ResponseWriter, r *http.Request) { var ( p provisioner.Interface err error ) + + id := r.URL.Query().Get("id") + name := chi.URLParam(r, "name") + auth := mustAuthority(r.Context()) + if len(id) > 0 { - if p, err = h.auth.LoadProvisionerByID(id); err != nil { + if p, err = auth.LoadProvisionerByID(id); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { - if p, err = h.auth.LoadProvisionerByName(name); err != nil { + if p, err = auth.LoadProvisionerByName(name); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } - if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { + if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) return } @@ -124,23 +128,27 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { } // UpdateProvisioner updates an existing prov. -func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { +func UpdateProvisioner(w http.ResponseWriter, r *http.Request) { var nu = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, nu); err != nil { render.Error(w, err) return } + ctx := r.Context() name := chi.URLParam(r, "name") - _old, err := h.auth.LoadProvisionerByName(name) + auth := mustAuthority(ctx) + db := admin.MustFromContext(ctx) + + p, err := auth.LoadProvisionerByName(name) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) return } - old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID()) + old, err := db.GetProvisioner(r.Context(), p.GetID()) if err != nil { - render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID())) return } @@ -171,7 +179,7 @@ func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { return } - if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil { + if err := auth.UpdateProvisioner(r.Context(), nu); err != nil { render.Error(w, err) return } diff --git a/authority/admin/api/provisioner_test.go b/authority/admin/api/provisioner_test.go index 486b6cda..d050bca6 100644 --- a/authority/admin/api/provisioner_test.go +++ b/authority/admin/api/provisioner_test.go @@ -50,6 +50,7 @@ func TestHandler_GetProvisioner(t *testing.T) { ctx: ctx, req: req, auth: auth, + adminDB: &admin.MockDB{}, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), @@ -74,6 +75,7 @@ func TestHandler_GetProvisioner(t *testing.T) { ctx: ctx, req: req, auth: auth, + adminDB: &admin.MockDB{}, statusCode: 500, err: &admin.Error{ Type: admin.ErrorServerInternalType.String(), @@ -156,13 +158,11 @@ func TestHandler_GetProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - } - req := tc.req.WithContext(tc.ctx) + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) + req := tc.req.WithContext(ctx) w := httptest.NewRecorder() - h.GetProvisioner(w, req) + GetProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -280,12 +280,10 @@ func TestHandler_GetProvisioners(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetProvisioners(w, req) + GetProvisioners(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -405,13 +403,11 @@ func TestHandler_CreateProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.CreateProvisioner(w, req) + CreateProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -571,12 +567,10 @@ func TestHandler_DeleteProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - } + mockMustAuthority(t, tc.auth) req := tc.req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.DeleteProvisioner(w, req) + DeleteProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) @@ -625,6 +619,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { return test{ ctx: context.Background(), body: body, + adminDB: &admin.MockDB{}, statusCode: 400, err: &admin.Error{ Type: "badRequest", @@ -654,6 +649,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) { return test{ ctx: ctx, body: body, + adminDB: &admin.MockDB{}, auth: auth, statusCode: 500, err: &admin.Error{ @@ -1061,14 +1057,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) { for name, prep := range tests { tc := prep(t) t.Run(name, func(t *testing.T) { - h := &Handler{ - auth: tc.auth, - adminDB: tc.adminDB, - } + mockMustAuthority(t, tc.auth) + ctx := admin.NewContext(tc.ctx, tc.adminDB) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) - req = req.WithContext(tc.ctx) + req = req.WithContext(ctx) w := httptest.NewRecorder() - h.UpdateProvisioner(w, req) + UpdateProvisioner(w, req) res := w.Result() assert.Equals(t, tc.statusCode, res.StatusCode) diff --git a/authority/admin/db.go b/authority/admin/db.go index 0c0e7767..b331cc0a 100644 --- a/authority/admin/db.go +++ b/authority/admin/db.go @@ -76,6 +76,29 @@ type DB interface { DeleteAuthorityPolicy(ctx context.Context) error } +type dbKey struct{} + +// NewContext adds the given admin database to the context. +func NewContext(ctx context.Context, db DB) context.Context { + return context.WithValue(ctx, dbKey{}, db) +} + +// FromContext returns the current admin database from the given context. +func FromContext(ctx context.Context) (db DB, ok bool) { + db, ok = ctx.Value(dbKey{}).(DB) + return +} + +// MustFromContext returns the current admin database from the given context. It +// will panic if it's not in the context. +func MustFromContext(ctx context.Context) DB { + if db, ok := FromContext(ctx); !ok { + panic("admin database is not in the context") + } else { + return db + } +} + // MockDB is an implementation of the DB interface that should only be used as // a mock in tests. type MockDB struct { diff --git a/authority/authority.go b/authority/authority.go index 3ce5acfd..5fa4b0fc 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -167,6 +167,29 @@ func NewEmbedded(opts ...Option) (*Authority, error) { return a, nil } +type authorityKey struct{} + +// NewContext adds the given authority to the context. +func NewContext(ctx context.Context, a *Authority) context.Context { + return context.WithValue(ctx, authorityKey{}, a) +} + +// FromContext returns the current authority from the given context. +func FromContext(ctx context.Context) (a *Authority, ok bool) { + a, ok = ctx.Value(authorityKey{}).(*Authority) + return +} + +// MustFromContext returns the current authority from the given context. It will +// panic if the authority is not in the context. +func MustFromContext(ctx context.Context) *Authority { + if a, ok := FromContext(ctx); !ok { + panic("authority is not in the context") + } else { + return a + } +} + // ReloadAdminResources reloads admins and provisioners from the DB. func (a *Authority) ReloadAdminResources(ctx context.Context) error { var ( @@ -235,6 +258,7 @@ func (a *Authority) init() error { } var err error + ctx := NewContext(context.Background(), a) // Set password if they are not set. var configPassword []byte @@ -270,7 +294,7 @@ func (a *Authority) init() error { if a.config.KMS != nil { options = *a.config.KMS } - a.keyManager, err = kms.New(context.Background(), options) + a.keyManager, err = kms.New(ctx, options) if err != nil { return err } @@ -300,7 +324,7 @@ func (a *Authority) init() error { // Configure linked RA if linkedcaClient != nil && options.CertificateAuthority == "" { - conf, err := linkedcaClient.GetConfiguration(context.Background()) + conf, err := linkedcaClient.GetConfiguration(ctx) if err != nil { return err } @@ -334,7 +358,7 @@ func (a *Authority) init() error { } } - a.x509CAService, err = cas.New(context.Background(), options) + a.x509CAService, err = cas.New(ctx, options) if err != nil { return err } @@ -521,7 +545,7 @@ func (a *Authority) init() error { } } - a.scepService, err = scep.NewService(context.Background(), options) + a.scepService, err = scep.NewService(ctx, options) if err != nil { return err } @@ -543,19 +567,19 @@ func (a *Authority) init() error { } } - provs, err := a.adminDB.GetProvisioners(context.Background()) + provs, err := a.adminDB.GetProvisioners(ctx) if err != nil { return admin.WrapErrorISE(err, "error loading provisioners to initialize authority") } if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") { // Create First Provisioner - prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password)) + prov, err := CreateFirstProvisioner(ctx, a.adminDB, string(a.password)) if err != nil { return admin.WrapErrorISE(err, "error creating first provisioner") } // Create first admin - if err := a.adminDB.CreateAdmin(context.Background(), &linkedca.Admin{ + if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{ ProvisionerId: prov.Id, Subject: "step", Type: linkedca.Admin_SUPER_ADMIN, @@ -571,7 +595,7 @@ func (a *Authority) init() error { } // Load x509 and SSH Policy Engines - if err := a.reloadPolicyEngines(context.Background()); err != nil { + if err := a.reloadPolicyEngines(ctx); err != nil { return err } @@ -596,6 +620,15 @@ func (a *Authority) init() error { return nil } +// GetID returns the define authority id or a zero uuid. +func (a *Authority) GetID() string { + const zeroUUID = "00000000-0000-0000-0000-000000000000" + if id := a.config.AuthorityConfig.AuthorityID; id != "" { + return id + } + return zeroUUID +} + // GetDatabase returns the authority database. If the configuration does not // define a database, GetDatabase will return a db.SimpleDB instance. func (a *Authority) GetDatabase() db.AuthDB { diff --git a/authority/authority_test.go b/authority/authority_test.go index 1f63333d..9f35f23e 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -14,6 +14,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "go.step.sm/crypto/jose" @@ -421,3 +422,31 @@ func TestAuthority_GetSCEPService(t *testing.T) { }) } } + +func TestAuthority_GetID(t *testing.T) { + type fields struct { + authorityID string + } + tests := []struct { + name string + fields fields + want string + }{ + {"ok", fields{""}, "00000000-0000-0000-0000-000000000000"}, + {"ok with id", fields{"10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, "10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + AuthorityID: tt.fields.authorityID, + }, + }, + } + if got := a.GetID(); got != tt.want { + t.Errorf("Authority.GetID() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/authorize.go b/authority/authorize.go index 21e02069..e23f2e5f 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -260,8 +260,7 @@ func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisio // AuthorizeSign authorizes a signature request by validating and authenticating // a token that must be sent w/ the request. // -// NOTE: This method is deprecated and should not be used. We make it available -// in the short term os as not to break existing clients. +// Deprecated: Use Authorize(context.Context, string) ([]provisioner.SignOption, error). func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) { ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) return a.Authorize(ctx, token) diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 10e22519..ccbdbc22 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -54,7 +54,11 @@ func startCABootstrapServer() *httptest.Server { if err != nil { panic(err) } + baseContext := buildContext(ca.auth, nil, nil, nil) srv.Config.Handler = ca.srv.Handler + srv.Config.BaseContext = func(net.Listener) context.Context { + return baseContext + } srv.TLS = ca.srv.TLSConfig srv.StartTLS() // Force the use of GetCertificate on IPs diff --git a/ca/ca.go b/ca/ca.go index 3cb4646b..9252fff7 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -1,10 +1,12 @@ package ca import ( + "context" "crypto/tls" "crypto/x509" "fmt" "log" + "net" "net/http" "net/url" "reflect" @@ -18,6 +20,7 @@ import ( acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/admin" adminAPI "github.com/smallstep/certificates/authority/admin/api" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/db" @@ -170,10 +173,9 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { insecureHandler := http.Handler(insecureMux) // Add regular CA api endpoints in / and /1.0 - routerHandler := api.New(auth) - routerHandler.Route(mux) + api.Route(mux) mux.Route("/1.0", func(r chi.Router) { - routerHandler.Route(r) + api.Route(r) }) //Add ACME api endpoints in /acme and /1.0/acme @@ -187,49 +189,41 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { dns = fmt.Sprintf("%s:%s", dns, port) } - // ACME Router - prefix := "acme" + // ACME Router is only available if we have a database. var acmeDB acme.DB - if cfg.DB == nil { - acmeDB = nil - } else { + var acmeLinker acme.Linker + if cfg.DB != nil { acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) if err != nil { return nil, errors.Wrap(err, "error configuring ACME DB interface") } + acmeLinker = acme.NewLinker(dns, "acme") + mux.Route("/acme", func(r chi.Router) { + acmeAPI.Route(r) + }) + // Use 2.0 because, at the moment, our ACME api is only compatible with v2.0 + // of the ACME spec. + mux.Route("/2.0/acme", func(r chi.Router) { + acmeAPI.Route(r) + }) } - acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ - Backdate: *cfg.AuthorityConfig.Backdate, - DB: acmeDB, - DNS: dns, - Prefix: prefix, - CA: auth, - }) - mux.Route("/"+prefix, func(r chi.Router) { - acmeHandler.Route(r) - }) - // Use 2.0 because, at the moment, our ACME api is only compatible with v2.0 - // of the ACME spec. - mux.Route("/2.0/"+prefix, func(r chi.Router) { - acmeHandler.Route(r) - }) // Admin API Router if cfg.AuthorityConfig.EnableAdmin { adminDB := auth.GetAdminDatabase() if adminDB != nil { acmeAdminResponder := adminAPI.NewACMEAdminResponder() - policyAdminResponder := adminAPI.NewPolicyAdminResponder(auth, adminDB, acmeDB) - adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder, policyAdminResponder) + policyAdminResponder := adminAPI.NewPolicyAdminResponder() mux.Route("/admin", func(r chi.Router) { - adminHandler.Route(r) + adminAPI.Route(r, acmeAdminResponder, policyAdminResponder) }) } } + var scepAuthority *scep.Authority if ca.shouldServeSCEPEndpoints() { scepPrefix := "scep" - scepAuthority, err := scep.New(auth, scep.AuthorityOptions{ + scepAuthority, err = scep.New(auth, scep.AuthorityOptions{ Service: auth.GetSCEPService(), DNS: dns, Prefix: scepPrefix, @@ -237,13 +231,12 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { if err != nil { return nil, errors.Wrap(err, "error creating SCEP authority") } - scepRouterHandler := scepAPI.New(scepAuthority) // According to the RFC (https://tools.ietf.org/html/rfc8894#section-7.10), // SCEP operations are performed using HTTP, so that's why the API is mounted // to the insecure mux. insecureMux.Route("/"+scepPrefix, func(r chi.Router) { - scepRouterHandler.Route(r) + scepAPI.Route(r) }) // The RFC also mentions usage of HTTPS, but seems to advise @@ -253,7 +246,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { // as well as HTTPS can be used to request certificates // using SCEP. mux.Route("/"+scepPrefix, func(r chi.Router) { - scepRouterHandler.Route(r) + scepAPI.Route(r) }) } @@ -280,7 +273,13 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { insecureHandler = logger.Middleware(insecureHandler) } + // Create context with all the necessary values. + baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker) + ca.srv = server.New(cfg.Address, handler, tlsConfig) + ca.srv.BaseContext = func(net.Listener) context.Context { + return baseContext + } // only start the insecure server if the insecure address is configured // and, currently, also only when it should serve SCEP endpoints. @@ -290,11 +289,32 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { // will probably introduce more complexity in terms of graceful // reload. ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil) + ca.insecureSrv.BaseContext = func(net.Listener) context.Context { + return baseContext + } } return ca, nil } +// buildContext builds the server base context. +func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context { + ctx := authority.NewContext(context.Background(), a) + if authDB := a.GetDatabase(); authDB != nil { + ctx = db.NewContext(ctx, authDB) + } + if adminDB := a.GetAdminDatabase(); adminDB != nil { + ctx = admin.NewContext(ctx, adminDB) + } + if scepAuthority != nil { + ctx = scep.NewContext(ctx, scepAuthority) + } + if acmeDB != nil { + ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil) + } + return ctx +} + // Run starts the CA calling to the server ListenAndServe method. func (ca *CA) Run() error { var wg sync.WaitGroup diff --git a/ca/ca_test.go b/ca/ca_test.go index e4c35a90..29eac575 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -2,6 +2,7 @@ package ca import ( "bytes" + "context" "crypto" "crypto/rand" "crypto/sha1" @@ -281,7 +282,8 @@ ZEp7knvU2psWRw== assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -360,7 +362,8 @@ func TestCAProvisioners(t *testing.T) { assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -426,7 +429,8 @@ func TestCAProvisionerEncryptedKey(t *testing.T) { assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -487,7 +491,8 @@ func TestCARoot(t *testing.T) { assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -534,7 +539,8 @@ func TestCAHealth(t *testing.T) { assert.FatalError(t, err) rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} @@ -628,7 +634,8 @@ func TestCARenew(t *testing.T) { rq.TLS = tc.tlsConnState rr := httptest.NewRecorder() - tc.ca.srv.Handler.ServeHTTP(rr, rq) + ctx := authority.NewContext(context.Background(), tc.ca.auth) + tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx)) if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} diff --git a/ca/tls_test.go b/ca/tls_test.go index 93dbe9b3..946a6cb5 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -10,6 +10,7 @@ import ( "encoding/hex" "io" "log" + "net" "net/http" "net/http/httptest" "reflect" @@ -77,7 +78,12 @@ func startCATestServer() *httptest.Server { panic(err) } // Use a httptest.Server instead - return startTestServer(ca.srv.TLSConfig, ca.srv.Handler) + srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler) + baseContext := buildContext(ca.auth, nil, nil, nil) + srv.Config.BaseContext = func(net.Listener) context.Context { + return baseContext + } + return srv } func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) { diff --git a/cas/vaultcas/auth/approle/approle.go b/cas/vaultcas/auth/approle/approle.go new file mode 100644 index 00000000..118afb10 --- /dev/null +++ b/cas/vaultcas/auth/approle/approle.go @@ -0,0 +1,67 @@ +package approle + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/hashicorp/vault/api/auth/approle" +) + +// AuthOptions defines the configuration options added using the +// VaultOptions.AuthOptions field when AuthType is approle +type AuthOptions struct { + RoleID string `json:"roleID,omitempty"` + SecretID string `json:"secretID,omitempty"` + SecretIDFile string `json:"secretIDFile,omitempty"` + SecretIDEnv string `json:"secretIDEnv,omitempty"` + IsWrappingToken bool `json:"isWrappingToken,omitempty"` +} + +func NewApproleAuthMethod(mountPath string, options json.RawMessage) (*approle.AppRoleAuth, error) { + var opts *AuthOptions + + err := json.Unmarshal(options, &opts) + if err != nil { + return nil, fmt.Errorf("error decoding AppRole auth options: %w", err) + } + + var approleAuth *approle.AppRoleAuth + + var loginOptions []approle.LoginOption + if mountPath != "" { + loginOptions = append(loginOptions, approle.WithMountPath(mountPath)) + } + if opts.IsWrappingToken { + loginOptions = append(loginOptions, approle.WithWrappingToken()) + } + + if opts.RoleID == "" { + return nil, errors.New("you must set roleID") + } + + var sid approle.SecretID + switch { + case opts.SecretID != "" && opts.SecretIDFile == "" && opts.SecretIDEnv == "": + sid = approle.SecretID{ + FromString: opts.SecretID, + } + case opts.SecretIDFile != "" && opts.SecretID == "" && opts.SecretIDEnv == "": + sid = approle.SecretID{ + FromFile: opts.SecretIDFile, + } + case opts.SecretIDEnv != "" && opts.SecretIDFile == "" && opts.SecretID == "": + sid = approle.SecretID{ + FromEnv: opts.SecretIDEnv, + } + default: + return nil, errors.New("you must set one of secretID, secretIDFile or secretIDEnv") + } + + approleAuth, err = approle.NewAppRoleAuth(opts.RoleID, &sid, loginOptions...) + if err != nil { + return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err) + } + + return approleAuth, nil +} diff --git a/cas/vaultcas/auth/approle/approle_test.go b/cas/vaultcas/auth/approle/approle_test.go new file mode 100644 index 00000000..28b7b7f7 --- /dev/null +++ b/cas/vaultcas/auth/approle/approle_test.go @@ -0,0 +1,195 @@ +package approle + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + vault "github.com/hashicorp/vault/api" +) + +func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { + t.Helper() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.RequestURI == "/v1/auth/approle/login": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{ + "auth": { + "client_token": "hvs.0000" + } + }`) + case r.RequestURI == "/v1/auth/custom-approle/login": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{ + "auth": { + "client_token": "hvs.9999" + } + }`) + default: + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"not found"}`) + } + })) + t.Cleanup(func() { + srv.Close() + }) + u, err := url.Parse(srv.URL) + if err != nil { + srv.Close() + t.Fatal(err) + } + + config := vault.DefaultConfig() + config.Address = srv.URL + + client, err := vault.NewClient(config) + if err != nil { + srv.Close() + t.Fatal(err) + } + + return u, client +} + +func TestApprole_LoginMountPaths(t *testing.T) { + caURL, _ := testCAHelper(t) + + config := vault.DefaultConfig() + config.Address = caURL.String() + client, _ := vault.NewClient(config) + + tests := []struct { + name string + mountPath string + token string + }{ + { + name: "ok default mount path", + mountPath: "", + token: "hvs.0000", + }, + { + name: "ok explicit mount path", + mountPath: "approle", + token: "hvs.0000", + }, + { + name: "ok custom mount path", + mountPath: "custom-approle", + token: "hvs.9999", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + method, err := NewApproleAuthMethod(tt.mountPath, json.RawMessage(`{"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}`)) + if err != nil { + t.Errorf("NewApproleAuthMethod() error = %v", err) + return + } + + secret, err := client.Auth().Login(context.Background(), method) + if err != nil { + t.Errorf("Login() error = %v", err) + return + } + + token, _ := secret.TokenID() + if token != tt.token { + t.Errorf("Token error got %v, expected %v", token, tt.token) + return + } + }) + } +} + +func TestApprole_NewApproleAuthMethod(t *testing.T) { + tests := []struct { + name string + mountPath string + raw string + wantErr bool + }{ + { + "ok secret-id string", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000"}`, + false, + }, + { + "ok secret-id string and wrapped", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "isWrappedToken": true}`, + false, + }, + { + "ok secret-id string and wrapped with custom mountPath", + "approle2", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "isWrappedToken": true}`, + false, + }, + { + "ok secret-id file", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id"}`, + false, + }, + { + "ok secret-id env", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, + false, + }, + { + "fail mandatory role-id", + "", + `{}`, + true, + }, + { + "fail mandatory secret-id any", + "", + `{"RoleID": "0000-0000-0000-0000"}`, + true, + }, + { + "fail multiple secret-id types id and env", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, + true, + }, + { + "fail multiple secret-id types id and file", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id"}`, + true, + }, + { + "fail multiple secret-id types env and file", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, + true, + }, + { + "fail multiple secret-id types all", + "", + `{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewApproleAuthMethod(tt.mountPath, json.RawMessage(tt.raw)) + if (err != nil) != tt.wantErr { + t.Errorf("Approle.NewApproleAuthMethod() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/cas/vaultcas/auth/kubernetes/kubernetes.go b/cas/vaultcas/auth/kubernetes/kubernetes.go new file mode 100644 index 00000000..267bcdca --- /dev/null +++ b/cas/vaultcas/auth/kubernetes/kubernetes.go @@ -0,0 +1,49 @@ +package kubernetes + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/hashicorp/vault/api/auth/kubernetes" +) + +// AuthOptions defines the configuration options added using the +// VaultOptions.AuthOptions field when AuthType is kubernetes +type AuthOptions struct { + Role string `json:"role,omitempty"` + TokenPath string `json:"tokenPath,omitempty"` +} + +func NewKubernetesAuthMethod(mountPath string, options json.RawMessage) (*kubernetes.KubernetesAuth, error) { + var opts *AuthOptions + + err := json.Unmarshal(options, &opts) + if err != nil { + return nil, fmt.Errorf("error decoding Kubernetes auth options: %w", err) + } + + var kubernetesAuth *kubernetes.KubernetesAuth + + var loginOptions []kubernetes.LoginOption + if mountPath != "" { + loginOptions = append(loginOptions, kubernetes.WithMountPath(mountPath)) + } + if opts.TokenPath != "" { + loginOptions = append(loginOptions, kubernetes.WithServiceAccountTokenPath(opts.TokenPath)) + } + + if opts.Role == "" { + return nil, errors.New("you must set role") + } + + kubernetesAuth, err = kubernetes.NewKubernetesAuth( + opts.Role, + loginOptions..., + ) + if err != nil { + return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err) + } + + return kubernetesAuth, nil +} diff --git a/cas/vaultcas/auth/kubernetes/kubernetes_test.go b/cas/vaultcas/auth/kubernetes/kubernetes_test.go new file mode 100644 index 00000000..55be904d --- /dev/null +++ b/cas/vaultcas/auth/kubernetes/kubernetes_test.go @@ -0,0 +1,149 @@ +package kubernetes + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "path" + "path/filepath" + "runtime" + "testing" + + vault "github.com/hashicorp/vault/api" +) + +func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { + t.Helper() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.RequestURI == "/v1/auth/kubernetes/login": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{ + "auth": { + "client_token": "hvs.0000" + } + }`) + case r.RequestURI == "/v1/auth/custom-kubernetes/login": + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{ + "auth": { + "client_token": "hvs.9999" + } + }`) + default: + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, `{"error":"not found"}`) + } + })) + t.Cleanup(func() { + srv.Close() + }) + u, err := url.Parse(srv.URL) + if err != nil { + srv.Close() + t.Fatal(err) + } + + config := vault.DefaultConfig() + config.Address = srv.URL + + client, err := vault.NewClient(config) + if err != nil { + srv.Close() + t.Fatal(err) + } + + return u, client +} + +func TestApprole_LoginMountPaths(t *testing.T) { + caURL, _ := testCAHelper(t) + _, filename, _, _ := runtime.Caller(0) + tokenPath := filepath.Join(path.Dir(filename), "token") + + config := vault.DefaultConfig() + config.Address = caURL.String() + client, _ := vault.NewClient(config) + + tests := []struct { + name string + mountPath string + token string + }{ + { + name: "ok default mount path", + mountPath: "", + token: "hvs.0000", + }, + { + name: "ok explicit mount path", + mountPath: "kubernetes", + token: "hvs.0000", + }, + { + name: "ok custom mount path", + mountPath: "custom-kubernetes", + token: "hvs.9999", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + method, err := NewKubernetesAuthMethod(tt.mountPath, json.RawMessage(`{"role": "SomeRoleName", "tokenPath": "`+tokenPath+`"}`)) + if err != nil { + t.Errorf("NewApproleAuthMethod() error = %v", err) + return + } + + secret, err := client.Auth().Login(context.Background(), method) + if err != nil { + t.Errorf("Login() error = %v", err) + return + } + + token, _ := secret.TokenID() + if token != tt.token { + t.Errorf("Token error got %v, expected %v", token, tt.token) + return + } + }) + } +} + +func TestApprole_NewApproleAuthMethod(t *testing.T) { + _, filename, _, _ := runtime.Caller(0) + tokenPath := filepath.Join(path.Dir(filename), "token") + + tests := []struct { + name string + mountPath string + raw string + wantErr bool + }{ + { + "ok secret-id string", + "", + `{"role": "SomeRoleName", "tokenPath": "` + tokenPath + `"}`, + false, + }, + { + "fail mandatory role", + "", + `{}`, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewKubernetesAuthMethod(tt.mountPath, json.RawMessage(tt.raw)) + if (err != nil) != tt.wantErr { + t.Errorf("Kubernetes.NewKubernetesAuthMethod() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/cas/vaultcas/auth/kubernetes/token b/cas/vaultcas/auth/kubernetes/token new file mode 100644 index 00000000..6745be67 --- /dev/null +++ b/cas/vaultcas/auth/kubernetes/token @@ -0,0 +1 @@ +token \ No newline at end of file diff --git a/cas/vaultcas/vaultcas.go b/cas/vaultcas/vaultcas.go index c29ef691..a5658620 100644 --- a/cas/vaultcas/vaultcas.go +++ b/cas/vaultcas/vaultcas.go @@ -15,9 +15,10 @@ import ( "time" "github.com/smallstep/certificates/cas/apiv1" + "github.com/smallstep/certificates/cas/vaultcas/auth/approle" + "github.com/smallstep/certificates/cas/vaultcas/auth/kubernetes" vault "github.com/hashicorp/vault/api" - auth "github.com/hashicorp/vault/api/auth/approle" ) func init() { @@ -29,15 +30,14 @@ func init() { // VaultOptions defines the configuration options added using the // apiv1.Options.Config field. type VaultOptions struct { - PKI string `json:"pki,omitempty"` - PKIRoleDefault string `json:"pkiRoleDefault,omitempty"` - PKIRoleRSA string `json:"pkiRoleRSA,omitempty"` - PKIRoleEC string `json:"pkiRoleEC,omitempty"` - PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"` - RoleID string `json:"roleID,omitempty"` - SecretID auth.SecretID `json:"secretID,omitempty"` - AppRole string `json:"appRole,omitempty"` - IsWrappingToken bool `json:"isWrappingToken,omitempty"` + PKIMountPath string `json:"pkiMountPath,omitempty"` + PKIRoleDefault string `json:"pkiRoleDefault,omitempty"` + PKIRoleRSA string `json:"pkiRoleRSA,omitempty"` + PKIRoleEC string `json:"pkiRoleEC,omitempty"` + PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"` + AuthType string `json:"authType,omitempty"` + AuthMountPath string `json:"authMountPath,omitempty"` + AuthOptions json.RawMessage `json:"authOptions,omitempty"` } // VaultCAS implements a Certificate Authority Service using Hashicorp Vault. @@ -77,28 +77,22 @@ func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) { return nil, fmt.Errorf("unable to initialize vault client: %w", err) } - var appRoleAuth *auth.AppRoleAuth - if vc.IsWrappingToken { - appRoleAuth, err = auth.NewAppRoleAuth( - vc.RoleID, - &vc.SecretID, - auth.WithWrappingToken(), - auth.WithMountPath(vc.AppRole), - ) - } else { - appRoleAuth, err = auth.NewAppRoleAuth( - vc.RoleID, - &vc.SecretID, - auth.WithMountPath(vc.AppRole), - ) + var method vault.AuthMethod + switch vc.AuthType { + case "kubernetes": + method, err = kubernetes.NewKubernetesAuthMethod(vc.AuthMountPath, vc.AuthOptions) + case "approle": + method, err = approle.NewApproleAuthMethod(vc.AuthMountPath, vc.AuthOptions) + default: + return nil, fmt.Errorf("unknown auth type: %s, only 'kubernetes' and 'approle' currently supported", vc.AuthType) } if err != nil { - return nil, fmt.Errorf("unable to initialize AppRole auth method: %w", err) + return nil, fmt.Errorf("unable to configure %s auth method: %w", vc.AuthType, err) } - authInfo, err := client.Auth().Login(ctx, appRoleAuth) + authInfo, err := client.Auth().Login(ctx, method) if err != nil { - return nil, fmt.Errorf("unable to login to AppRole auth method: %w", err) + return nil, fmt.Errorf("unable to login to %s auth method: %w", vc.AuthType, err) } if authInfo == nil { return nil, errors.New("no auth info was returned after login") @@ -134,7 +128,7 @@ func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv // GetCertificateAuthority returns the root certificate of the certificate // authority using the configured fingerprint. func (v *VaultCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) { - secret, err := v.client.Logical().Read(v.config.PKI + "/cert/ca_chain") + secret, err := v.client.Logical().Read(v.config.PKIMountPath + "/cert/ca_chain") if err != nil { return nil, fmt.Errorf("error reading ca chain: %w", err) } @@ -190,7 +184,7 @@ func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv vaultReq := map[string]interface{}{ "serial_number": formatSerialNumber(sn), } - _, err := v.client.Logical().Write(v.config.PKI+"/revoke/", vaultReq) + _, err := v.client.Logical().Write(v.config.PKIMountPath+"/revoke/", vaultReq) if err != nil { return nil, fmt.Errorf("error revoking certificate: %w", err) } @@ -224,7 +218,7 @@ func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time. "ttl": lifetime.Seconds(), } - secret, err := v.client.Logical().Write(v.config.PKI+"/sign/"+vaultPKIRole, vaultReq) + secret, err := v.client.Logical().Write(v.config.PKIMountPath+"/sign/"+vaultPKIRole, vaultReq) if err != nil { return nil, nil, fmt.Errorf("error signing certificate: %w", err) } @@ -247,21 +241,17 @@ func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time. } func loadOptions(config json.RawMessage) (*VaultOptions, error) { - var vc *VaultOptions + // setup default values + vc := VaultOptions{ + PKIMountPath: "pki", + PKIRoleDefault: "default", + } err := json.Unmarshal(config, &vc) if err != nil { return nil, fmt.Errorf("error decoding vaultCAS config: %w", err) } - if vc.PKI == "" { - vc.PKI = "pki" // use default pki vault name - } - - if vc.PKIRoleDefault == "" { - vc.PKIRoleDefault = "default" // use default pki role name - } - if vc.PKIRoleRSA == "" { vc.PKIRoleRSA = vc.PKIRoleDefault } @@ -272,23 +262,7 @@ func loadOptions(config json.RawMessage) (*VaultOptions, error) { vc.PKIRoleEd25519 = vc.PKIRoleDefault } - if vc.RoleID == "" { - return nil, errors.New("vaultCAS config options must define `roleID`") - } - - if vc.SecretID.FromEnv == "" && vc.SecretID.FromFile == "" && vc.SecretID.FromString == "" { - return nil, errors.New("vaultCAS config options must define `secretID` object with one of `FromEnv`, `FromFile` or `FromString`") - } - - if vc.PKI == "" { - vc.PKI = "pki" // use default pki vault name - } - - if vc.AppRole == "" { - vc.AppRole = "auth/approle" - } - - return vc, nil + return &vc, nil } func parseCertificates(pemCert string) []*x509.Certificate { diff --git a/cas/vaultcas/vaultcas_test.go b/cas/vaultcas/vaultcas_test.go index 9f73a1ee..0ea0c4b1 100644 --- a/cas/vaultcas/vaultcas_test.go +++ b/cas/vaultcas/vaultcas_test.go @@ -14,7 +14,6 @@ import ( "time" vault "github.com/hashicorp/vault/api" - auth "github.com/hashicorp/vault/api/auth/approle" "github.com/smallstep/certificates/cas/apiv1" "go.step.sm/crypto/pemutil" ) @@ -99,7 +98,7 @@ func testCAHelper(t *testing.T) (*url.URL, *vault.Client) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.RequestURI == "/v1/auth/auth/approle/login": + case r.RequestURI == "/v1/auth/approle/login": w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `{ "auth": { @@ -183,11 +182,8 @@ func TestNew_register(t *testing.T) { CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, Config: json.RawMessage(`{ - "PKI": "pki", - "PKIRoleDefault": "pki-role", - "RoleID": "roleID", - "SecretID": {"FromString": "secretID"}, - "IsWrappingToken": false + "AuthType": "approle", + "AuthOptions": {"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false} }`), }) @@ -201,15 +197,11 @@ func TestVaultCAS_CreateCertificate(t *testing.T) { _, client := testCAHelper(t) options := VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "rsa", - PKIRoleEC: "ec", - PKIRoleEd25519: "ed25519", - RoleID: "roleID", - SecretID: auth.SecretID{FromString: "secretID"}, - AppRole: "approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", } type fields struct { @@ -291,7 +283,7 @@ func TestVaultCAS_GetCertificateAuthority(t *testing.T) { } options := VaultOptions{ - PKI: "pki", + PKIMountPath: "pki", } rootCert := parseCertificates(testRootCertificate)[0] @@ -335,15 +327,11 @@ func TestVaultCAS_RevokeCertificate(t *testing.T) { _, client := testCAHelper(t) options := VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "rsa", - PKIRoleEC: "ec", - PKIRoleEd25519: "ed25519", - RoleID: "roleID", - SecretID: auth.SecretID{FromString: "secretID"}, - AppRole: "approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", } type fields struct { @@ -407,15 +395,11 @@ func TestVaultCAS_RenewCertificate(t *testing.T) { _, client := testCAHelper(t) options := VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "rsa", - PKIRoleEC: "ec", - PKIRoleEd25519: "ed25519", - RoleID: "roleID", - SecretID: auth.SecretID{FromString: "secretID"}, - AppRole: "approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", } type fields struct { @@ -464,202 +448,66 @@ func TestVaultCAS_loadOptions(t *testing.T) { want *VaultOptions wantErr bool }{ - { - "ok mandatory with SecretID FromString", - `{"RoleID": "roleID", "SecretID": {"FromString": "secretID"}}`, - &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "default", - PKIRoleRSA: "default", - PKIRoleEC: "default", - PKIRoleEd25519: "default", - RoleID: "roleID", - SecretID: auth.SecretID{FromString: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, - }, - false, - }, - { - "ok mandatory with SecretID FromFile", - `{"RoleID": "roleID", "SecretID": {"FromFile": "secretID"}}`, - &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "default", - PKIRoleRSA: "default", - PKIRoleEC: "default", - PKIRoleEd25519: "default", - RoleID: "roleID", - SecretID: auth.SecretID{FromFile: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, - }, - false, - }, - { - "ok mandatory with SecretID FromEnv", - `{"RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, - &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "default", - PKIRoleRSA: "default", - PKIRoleEC: "default", - PKIRoleEd25519: "default", - RoleID: "roleID", - SecretID: auth.SecretID{FromEnv: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, - }, - false, - }, { "ok mandatory PKIRole PKIRoleEd25519", - `{"PKIRoleDefault": "role", "PKIRoleEd25519": "ed25519" , "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, + `{"PKIRoleDefault": "role", "PKIRoleEd25519": "ed25519"}`, &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "role", - PKIRoleEC: "role", - PKIRoleEd25519: "ed25519", - RoleID: "roleID", - SecretID: auth.SecretID{FromEnv: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "role", + PKIRoleEC: "role", + PKIRoleEd25519: "ed25519", }, false, }, { "ok mandatory PKIRole PKIRoleEC", - `{"PKIRoleDefault": "role", "PKIRoleEC": "ec" , "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, + `{"PKIRoleDefault": "role", "PKIRoleEC": "ec"}`, &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "role", - PKIRoleEC: "ec", - PKIRoleEd25519: "role", - RoleID: "roleID", - SecretID: auth.SecretID{FromEnv: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "role", + PKIRoleEC: "ec", + PKIRoleEd25519: "role", }, false, }, { "ok mandatory PKIRole PKIRoleRSA", - `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa" , "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, + `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa"}`, &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "rsa", - PKIRoleEC: "role", - PKIRoleEd25519: "role", - RoleID: "roleID", - SecretID: auth.SecretID{FromEnv: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "role", + PKIRoleEd25519: "role", }, false, }, { "ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519", - `{"PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519", "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, + `{"PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`, &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "default", - PKIRoleRSA: "rsa", - PKIRoleEC: "ec", - PKIRoleEd25519: "ed25519", - RoleID: "roleID", - SecretID: auth.SecretID{FromEnv: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "default", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", }, false, }, { "ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519 with useless PKIRoleDefault", - `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519", "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, + `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`, &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "role", - PKIRoleRSA: "rsa", - PKIRoleEC: "ec", - PKIRoleEd25519: "ed25519", - RoleID: "roleID", - SecretID: auth.SecretID{FromEnv: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: false, + PKIMountPath: "pki", + PKIRoleDefault: "role", + PKIRoleRSA: "rsa", + PKIRoleEC: "ec", + PKIRoleEd25519: "ed25519", }, false, }, - { - "ok mandatory with AppRole", - `{"AppRole": "test", "RoleID": "roleID", "SecretID": {"FromString": "secretID"}}`, - &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "default", - PKIRoleRSA: "default", - PKIRoleEC: "default", - PKIRoleEd25519: "default", - RoleID: "roleID", - SecretID: auth.SecretID{FromString: "secretID"}, - AppRole: "test", - IsWrappingToken: false, - }, - false, - }, - { - "ok mandatory with IsWrappingToken", - `{"IsWrappingToken": true, "RoleID": "roleID", "SecretID": {"FromString": "secretID"}}`, - &VaultOptions{ - PKI: "pki", - PKIRoleDefault: "default", - PKIRoleRSA: "default", - PKIRoleEC: "default", - PKIRoleEd25519: "default", - RoleID: "roleID", - SecretID: auth.SecretID{FromString: "secretID"}, - AppRole: "auth/approle", - IsWrappingToken: true, - }, - false, - }, - { - "fail with SecretID FromFail", - `{"RoleID": "roleID", "SecretID": {"FromFail": "secretID"}}`, - nil, - true, - }, - { - "fail with SecretID empty FromEnv", - `{"RoleID": "roleID", "SecretID": {"FromEnv": ""}}`, - nil, - true, - }, - { - "fail with SecretID empty FromFile", - `{"RoleID": "roleID", "SecretID": {"FromFile": ""}}`, - nil, - true, - }, - { - "fail with SecretID empty FromString", - `{"RoleID": "roleID", "SecretID": {"FromString": ""}}`, - nil, - true, - }, - { - "fail mandatory with SecretID FromFail", - `{"RoleID": "roleID", "SecretID": {"FromFail": "secretID"}}`, - nil, - true, - }, - { - "fail missing RoleID", - `{"SecretID": {"FromString": "secretID"}}`, - nil, - true, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/db/db.go b/db/db.go index 8cd1db0f..05f10793 100644 --- a/db/db.go +++ b/db/db.go @@ -1,6 +1,7 @@ package db import ( + "context" "crypto/x509" "encoding/json" "strconv" @@ -56,6 +57,29 @@ type AuthDB interface { Shutdown() error } +type dbKey struct{} + +// NewContext adds the given authority database to the context. +func NewContext(ctx context.Context, db AuthDB) context.Context { + return context.WithValue(ctx, dbKey{}, db) +} + +// FromContext returns the current authority database from the given context. +func FromContext(ctx context.Context) (db AuthDB, ok bool) { + db, ok = ctx.Value(dbKey{}).(AuthDB) + return +} + +// MustFromContext returns the current database from the given context. It +// will panic if it's not in the context. +func MustFromContext(ctx context.Context) AuthDB { + if db, ok := FromContext(ctx); !ok { + panic("authority database is not in the context") + } else { + return db + } +} + // CertificateStorer is an extension of AuthDB that allows to store // certificates. type CertificateStorer interface { diff --git a/go.mod b/go.mod index 8b66f470..0b772018 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/googleapis/gax-go/v2 v2.1.1 github.com/hashicorp/vault/api v1.3.1 github.com/hashicorp/vault/api/auth/approle v0.1.1 + github.com/hashicorp/vault/api/auth/kubernetes v0.1.0 github.com/jhump/protoreflect v1.9.0 // indirect github.com/mattn/go-colorable v0.1.8 // indirect github.com/mattn/go-isatty v0.0.13 // indirect diff --git a/go.sum b/go.sum index 4780111e..d76648c2 100644 --- a/go.sum +++ b/go.sum @@ -449,6 +449,8 @@ github.com/hashicorp/vault/api v1.3.1 h1:pkDkcgTh47PRjY1NEFeofqR4W/HkNUi9qIakESO github.com/hashicorp/vault/api v1.3.1/go.mod h1:QeJoWxMFt+MsuWcYhmwRLwKEXrjwAFFywzhptMsTIUw= github.com/hashicorp/vault/api/auth/approle v0.1.1 h1:R5yA+xcNvw1ix6bDuWOaLOq2L4L77zDCVsethNw97xQ= github.com/hashicorp/vault/api/auth/approle v0.1.1/go.mod h1:mHOLgh//xDx4dpqXoq6tS8Ob0FoCFWLU2ibJ26Lfmag= +github.com/hashicorp/vault/api/auth/kubernetes v0.1.0 h1:6BtyahbF4aQp8gg3ww0A/oIoqzbhpNP1spXU3nHE0n0= +github.com/hashicorp/vault/api/auth/kubernetes v0.1.0/go.mod h1:Pdgk78uIs0mgDOLvc3a+h/vYIT9rznw2sz+ucuH9024= github.com/hashicorp/vault/sdk v0.3.0 h1:kR3dpxNkhh/wr6ycaJYqp6AFT/i2xaftbfnwZduTKEY= github.com/hashicorp/vault/sdk v0.3.0/go.mod h1:aZ3fNuL5VNydQk8GcLJ2TV8YCRVvyaakYkhZRoVuhj0= github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M= diff --git a/scep/api/api.go b/scep/api/api.go index fcabfc58..b738a933 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -38,8 +38,8 @@ type request struct { Message []byte } -// response is a SCEP server response. -type response struct { +// Response is a SCEP server Response. +type Response struct { Operation string CACertNum int Data []byte @@ -52,25 +52,48 @@ type handler struct { auth *scep.Authority } +// Route traffic and implement the Router interface. +// +// Deprecated: use scep.Route(r api.Router) +func (h *handler) Route(r api.Router) { + route(r, func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := scep.NewContext(r.Context(), h.auth) + next(w, r.WithContext(ctx)) + } + }) +} + // New returns a new SCEP API router. +// +// Deprecated: use scep.Route(r api.Router) func New(auth *scep.Authority) api.RouterHandler { - return &handler{ - auth: auth, - } + return &handler{auth: auth} } // Route traffic and implement the Router interface. -func (h *handler) Route(r api.Router) { - getLink := h.auth.GetLinkExplicit - r.MethodFunc(http.MethodGet, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Get)) - r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get)) - r.MethodFunc(http.MethodPost, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Post)) - r.MethodFunc(http.MethodPost, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Post)) +func Route(r api.Router) { + route(r, nil) } -// Get handles all SCEP GET requests -func (h *handler) Get(w http.ResponseWriter, r *http.Request) { +func route(r api.Router, middleware func(next http.HandlerFunc) http.HandlerFunc) { + getHandler := lookupProvisioner(Get) + postHandler := lookupProvisioner(Post) + + // For backward compatibility. + if middleware != nil { + getHandler = middleware(getHandler) + postHandler = middleware(postHandler) + } + + r.MethodFunc(http.MethodGet, "/{provisionerName}/*", getHandler) + r.MethodFunc(http.MethodGet, "/{provisionerName}", getHandler) + r.MethodFunc(http.MethodPost, "/{provisionerName}/*", postHandler) + r.MethodFunc(http.MethodPost, "/{provisionerName}", postHandler) +} +// Get handles all SCEP GET requests +func Get(w http.ResponseWriter, r *http.Request) { req, err := decodeRequest(r) if err != nil { fail(w, fmt.Errorf("invalid scep get request: %w", err)) @@ -78,15 +101,15 @@ func (h *handler) Get(w http.ResponseWriter, r *http.Request) { } ctx := r.Context() - var res response + var res Response switch req.Operation { case opnGetCACert: - res, err = h.GetCACert(ctx) + res, err = GetCACert(ctx) case opnGetCACaps: - res, err = h.GetCACaps(ctx) + res, err = GetCACaps(ctx) case opnPKIOperation: - res, err = h.PKIOperation(ctx, req) + res, err = PKIOperation(ctx, req) default: err = fmt.Errorf("unknown operation: %s", req.Operation) } @@ -100,20 +123,17 @@ func (h *handler) Get(w http.ResponseWriter, r *http.Request) { } // Post handles all SCEP POST requests -func (h *handler) Post(w http.ResponseWriter, r *http.Request) { - +func Post(w http.ResponseWriter, r *http.Request) { req, err := decodeRequest(r) if err != nil { fail(w, fmt.Errorf("invalid scep post request: %w", err)) return } - ctx := r.Context() - var res response - + var res Response switch req.Operation { case opnPKIOperation: - res, err = h.PKIOperation(ctx, req) + res, err = PKIOperation(r.Context(), req) default: err = fmt.Errorf("unknown operation: %s", req.Operation) } @@ -127,7 +147,6 @@ func (h *handler) Post(w http.ResponseWriter, r *http.Request) { } func decodeRequest(r *http.Request) (request, error) { - defer r.Body.Close() method := r.Method @@ -179,9 +198,8 @@ func decodeRequest(r *http.Request) (request, error) { // lookupProvisioner loads the provisioner associated with the request. // Responds 404 if the provisioner does not exist. -func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { +func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - name := chi.URLParam(r, "provisionerName") provisionerName, err := url.PathUnescape(name) if err != nil { @@ -189,7 +207,9 @@ func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { return } - p, err := h.auth.LoadProvisionerByName(provisionerName) + ctx := r.Context() + auth := scep.MustFromContext(ctx) + p, err := auth.LoadProvisionerByName(provisionerName) if err != nil { fail(w, err) return @@ -201,25 +221,24 @@ func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { return } - ctx := r.Context() ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov)) next(w, r.WithContext(ctx)) } } // GetCACert returns the CA certificates in a SCEP response -func (h *handler) GetCACert(ctx context.Context) (response, error) { - - certs, err := h.auth.GetCACertificates(ctx) +func GetCACert(ctx context.Context) (Response, error) { + auth := scep.MustFromContext(ctx) + certs, err := auth.GetCACertificates(ctx) if err != nil { - return response{}, err + return Response{}, err } if len(certs) == 0 { - return response{}, errors.New("missing CA cert") + return Response{}, errors.New("missing CA cert") } - res := response{ + res := Response{ Operation: opnGetCACert, CACertNum: len(certs), } @@ -232,7 +251,7 @@ func (h *handler) GetCACert(ctx context.Context) (response, error) { // not signed or encrypted data has to be returned. data, err := microscep.DegenerateCertificates(certs) if err != nil { - return response{}, err + return Response{}, err } res.Data = data } @@ -241,11 +260,11 @@ func (h *handler) GetCACert(ctx context.Context) (response, error) { } // GetCACaps returns the CA capabilities in a SCEP response -func (h *handler) GetCACaps(ctx context.Context) (response, error) { +func GetCACaps(ctx context.Context) (Response, error) { + auth := scep.MustFromContext(ctx) + caps := auth.GetCACaps(ctx) - caps := h.auth.GetCACaps(ctx) - - res := response{ + res := Response{ Operation: opnGetCACaps, Data: formatCapabilities(caps), } @@ -254,13 +273,12 @@ func (h *handler) GetCACaps(ctx context.Context) (response, error) { } // PKIOperation performs PKI operations and returns a SCEP response -func (h *handler) PKIOperation(ctx context.Context, req request) (response, error) { - +func PKIOperation(ctx context.Context, req request) (Response, error) { // parse the message using microscep implementation microMsg, err := microscep.ParsePKIMessage(req.Message) if err != nil { // return the error, because we can't use the msg for creating a CertRep - return response{}, err + return Response{}, err } // this is essentially doing the same as microscep.ParsePKIMessage, but @@ -268,7 +286,7 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro // wrapper for the microscep implementation. p7, err := pkcs7.Parse(microMsg.Raw) if err != nil { - return response{}, err + return Response{}, err } // copy over properties to our internal PKIMessage @@ -280,8 +298,9 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro P7: p7, } - if err := h.auth.DecryptPKIEnvelope(ctx, msg); err != nil { - return response{}, err + auth := scep.MustFromContext(ctx) + if err := auth.DecryptPKIEnvelope(ctx, msg); err != nil { + return Response{}, err } // NOTE: at this point we have sufficient information for returning nicely signed CertReps @@ -293,13 +312,13 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro // a certificate exists; then it will use RenewalReq. Adding the challenge check here may be a small breaking change for clients. // We'll have to see how it works out. if msg.MessageType == microscep.PKCSReq || msg.MessageType == microscep.RenewalReq { - challengeMatches, err := h.auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword) + challengeMatches, err := auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword) if err != nil { - return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password")) + return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password")) } if !challengeMatches { // TODO: can this be returned safely to the client? In the end, if the password was correct, that gains a bit of info too. - return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("wrong password provided")) + return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("wrong password provided")) } } @@ -311,12 +330,12 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro // Authentication by the (self-signed) certificate with an optional challenge is required; supporting renewals incl. verification // of the client cert is not. - certRep, err := h.auth.SignCSR(ctx, csr, msg) + certRep, err := auth.SignCSR(ctx, csr, msg) if err != nil { - return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err)) + return createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err)) } - res := response{ + res := Response{ Operation: opnPKIOperation, Data: certRep.Raw, Certificate: certRep.Certificate, @@ -330,7 +349,7 @@ func formatCapabilities(caps []string) []byte { } // writeResponse writes a SCEP response back to the SCEP client. -func writeResponse(w http.ResponseWriter, res response) { +func writeResponse(w http.ResponseWriter, res Response) { if res.Error != nil { log.Error(w, res.Error) @@ -350,19 +369,20 @@ func fail(w http.ResponseWriter, err error) { http.Error(w, err.Error(), http.StatusInternalServerError) } -func (h *handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (response, error) { - certRepMsg, err := h.auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error()) +func createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (Response, error) { + auth := scep.MustFromContext(ctx) + certRepMsg, err := auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error()) if err != nil { - return response{}, err + return Response{}, err } - return response{ + return Response{ Operation: opnPKIOperation, Data: certRepMsg.Raw, Error: failError, }, nil } -func contentHeader(r response) string { +func contentHeader(r Response) string { switch r.Operation { default: return "text/plain" diff --git a/scep/authority.go b/scep/authority.go index 71f92152..7dbbb8c5 100644 --- a/scep/authority.go +++ b/scep/authority.go @@ -27,6 +27,29 @@ type Authority struct { signAuth SignAuthority } +type authorityKey struct{} + +// NewContext adds the given authority to the context. +func NewContext(ctx context.Context, a *Authority) context.Context { + return context.WithValue(ctx, authorityKey{}, a) +} + +// FromContext returns the current authority from the given context. +func FromContext(ctx context.Context) (a *Authority, ok bool) { + a, ok = ctx.Value(authorityKey{}).(*Authority) + return +} + +// MustFromContext returns the current authority from the given context. It will +// panic if the authority is not in the context. +func MustFromContext(ctx context.Context) *Authority { + if a, ok := FromContext(ctx); !ok { + panic("scep authority is not in the context") + } else { + return a + } +} + // AuthorityOptions required to create a new SCEP Authority. type AuthorityOptions struct { // Service provides the certificate chain, the signer and the decrypter to the Authority @@ -163,7 +186,6 @@ func (a *Authority) GetCACertificates(ctx context.Context) ([]*x509.Certificate, // DecryptPKIEnvelope decrypts an enveloped message func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) error { - p7c, err := pkcs7.Parse(msg.P7.Content) if err != nil { return fmt.Errorf("error parsing pkcs7 content: %w", err) @@ -210,7 +232,6 @@ func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) err // SignCSR creates an x509.Certificate based on a CSR template and Cert Authority credentials // returns a new PKIMessage with CertRep data func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, msg *PKIMessage) (*PKIMessage, error) { - // TODO: intermediate storage of the request? In SCEP it's possible to request a csr/certificate // to be signed, which can be performed asynchronously / out-of-band. In that case a client can // poll for the status. It seems to be similar as what can happen in ACME, so might want to model @@ -432,7 +453,6 @@ func (a *Authority) CreateFailureResponse(ctx context.Context, csr *x509.Certifi // MatchChallengePassword verifies a SCEP challenge password func (a *Authority) MatchChallengePassword(ctx context.Context, password string) (bool, error) { - p, err := provisionerFromContext(ctx) if err != nil { return false, err