From 8d2ebcfd497a1f062512a5f8222c1c93bd935221 Mon Sep 17 00:00:00 2001 From: max furman Date: Fri, 12 Mar 2021 00:16:48 -0800 Subject: [PATCH] [acme db interface] more unit tests --- acme/account.go | 10 -- acme/api/account.go | 20 ++- acme/api/account_test.go | 28 +-- acme/api/order_test.go | 371 +++++++++++++++++++++++++++++---------- acme/order.go | 4 +- 5 files changed, 314 insertions(+), 119 deletions(-) diff --git a/acme/account.go b/acme/account.go index 354ebdc7..cb60e21d 100644 --- a/acme/account.go +++ b/acme/account.go @@ -27,16 +27,6 @@ func (a *Account) ToLog() (interface{}, error) { return string(b), nil } -// GetID returns the account ID. -func (a *Account) GetID() string { - return a.ID -} - -// GetKey returns the JWK associated with the account. -func (a *Account) GetKey() *jose.JSONWebKey { - return a.Key -} - // IsValid returns true if the Account is valid. func (a *Account) IsValid() bool { return Status(a.Status) == StatusValid diff --git a/acme/api/account.go b/acme/api/account.go index 2e15ad40..c7f3d11a 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -153,15 +153,17 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { api.WriteError(w, err) return } - var err error - // If neither the status nor the contacts are being updated then ignore - // the updates and return 200. This conforms with the behavior detailed - // in the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2). - acc.Status = uar.Status - acc.Contact = uar.Contact - if err = h.db.UpdateAccount(ctx, acc); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) - return + if len(uar.Status) > 0 || len(uar.Contact) > 0 { + if len(uar.Status) > 0 { + acc.Status = uar.Status + } else if len(uar.Contact) > 0 { + acc.Contact = uar.Contact + } + + if err := h.db.UpdateAccount(ctx, acc); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) + return + } } } diff --git a/acme/api/account_test.go b/acme/api/account_test.go index d8fdff84..28abffe1 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -168,16 +168,22 @@ func TestUpdateAccountRequest_Validate(t *testing.T) { } func TestHandler_GetOrdersByAccountID(t *testing.T) { - oids := []string{ - "https://ca.smallstep.com/acme/order/foo", - "https://ca.smallstep.com/acme/order/bar", + oids := []string{"foo", "bar"} + oidURLs := []string{ + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/foo", + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/bar", } accID := "account-id" // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("accID", accID) - url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s/orders", accID) + + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + + url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID) type test struct { db acme.DB @@ -189,15 +195,15 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { "fail/no-account": func(t *testing.T) test { return test{ db: &acme.MockDB{}, + ctx: context.Background(), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ db: &acme.MockDB{}, - ctx: ctx, + ctx: context.WithValue(context.Background(), accContextKey, nil), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } @@ -213,7 +219,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { err: acme.NewError(acme.ErrorUnauthorizedType, "account ID does not match url param"), } }, - "fail/getOrdersByAccount-error": func(t *testing.T) test { + "fail/db.GetOrdersByAccountID-error": func(t *testing.T) test { acc := &acme.Account{ID: accID} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -230,6 +236,8 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { acc := &acme.Account{ID: accID} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) return test{ db: &acme.MockDB{ MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) { @@ -245,7 +253,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -268,7 +276,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) { assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { - expB, err := json.Marshal(oids) + expB, err := json.Marshal(oidURLs) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) @@ -558,7 +566,7 @@ func TestHandler_GetUpdateAccount(t *testing.T) { err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, - "fail/update-error": func(t *testing.T) test { + "fail/db.UpdateAccount-error": func(t *testing.T) test { uar := &UpdateAccountRequest{ Status: "deactivated", } diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 610713b6..b6783e34 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -19,7 +19,7 @@ import ( "go.step.sm/crypto/pemutil" ) -func TestNewOrderRequestValidate(t *testing.T) { +func TestNewOrderRequest_Validate(t *testing.T) { type test struct { nor *NewOrderRequest nbf, naf time.Time @@ -148,12 +148,12 @@ func TestFinalizeRequestValidate(t *testing.T) { } func TestHandler_GetOrder(t *testing.T) { - expiry := time.Now().UTC().Add(6 * time.Hour) - nbf := time.Now().UTC() - naf := time.Now().UTC().Add(24 * time.Hour) + now := clock.Now() + nbf := now + naf := now.Add(24 * time.Hour) + expiry := now.Add(-time.Hour) o := acme.Order{ ID: "orderID", - Expires: expiry, NotBefore: nbf, NotAfter: naf, Identifiers: []acme.Identifier{ @@ -166,8 +166,15 @@ func TestHandler_GetOrder(t *testing.T) { Value: "*.smallstep.com", }, }, - Status: "pending", - AuthorizationURLs: []string{"foo", "bar"}, + Expires: expiry, + Status: acme.StatusInvalid, + Error: acme.NewError(acme.ErrorMalformedType, "order has expired"), + AuthorizationURLs: []string{ + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/foo", + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/bar", + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/baz", + }, + FinalizeURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/orderID/finalize", } // Request with chi context @@ -181,7 +188,6 @@ func TestHandler_GetOrder(t *testing.T) { type test struct { db acme.DB - linker Linker ctx context.Context statusCode int err *acme.Error @@ -203,8 +209,27 @@ func TestHandler_GetOrder(t *testing.T) { err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, - "fail/getOrder-error": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} + "fail/no-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx = context.WithValue(ctx, accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/db.GetOrder-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -217,8 +242,64 @@ func TestHandler_GetOrder(t *testing.T) { err: acme.NewErrorISE("force"), } }, + "fail/account-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{AccountID: "foo"}, nil + }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), + } + }, + "fail/provisioner-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{AccountID: "accountID", ProvisionerID: "bar"}, nil + }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "provisioner id mismatch"), + } + }, + "fail/order-update-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{ + AccountID: "accountID", + ProvisionerID: "acme/test@acme-provisioner.com", + Expires: clock.Now().Add(-time.Hour), + Status: acme.StatusReady, + }, nil + }, + MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { + return acme.NewErrorISE("force") + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), + } + }, "ok": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} + acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -226,11 +307,31 @@ func TestHandler_GetOrder(t *testing.T) { return test{ db: &acme.MockDB{ MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { - assert.Equals(t, id, o.ID) - return &o, nil + return &acme.Order{ + ID: "orderID", + AccountID: "accountID", + ProvisionerID: "acme/test@acme-provisioner.com", + Expires: expiry, + Status: acme.StatusReady, + AuthorizationIDs: []string{"foo", "bar", "baz"}, + NotBefore: nbf, + NotAfter: naf, + Identifiers: []acme.Identifier{ + { + Type: "dns", + Value: "example.com", + }, + { + Type: "dns", + Value: "*.smallstep.com", + }, + }, + }, nil + }, + MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { + return nil }, }, - linker: NewLinker("dns", "acme"), ctx: ctx, statusCode: 200, } @@ -239,7 +340,7 @@ func TestHandler_GetOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: tc.linker, db: tc.db} + h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -264,6 +365,7 @@ func TestHandler_GetOrder(t *testing.T) { } else { expB, err := json.Marshal(o) assert.FatalError(t, err) + assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) @@ -272,7 +374,7 @@ func TestHandler_GetOrder(t *testing.T) { } } -func TestHandlerNewOrder(t *testing.T) { +func TestHandler_NewOrder(t *testing.T) { expiry := time.Now().UTC().Add(6 * time.Hour) nbf := time.Now().UTC().Add(5 * time.Hour) naf := nbf.Add(17 * time.Hour) @@ -297,7 +399,6 @@ func TestHandlerNewOrder(t *testing.T) { type test struct { db acme.DB - linker Linker ctx context.Context statusCode int err *acme.Error @@ -319,14 +420,23 @@ func TestHandlerNewOrder(t *testing.T) { err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, - "fail/no-payload": func(t *testing.T) test { + "fail/no-provisioner": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := context.WithValue(context.Background(), accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner expected in request context"), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx = context.WithValue(ctx, accContextKey, acc) return test{ ctx: ctx, statusCode: 500, - err: acme.NewErrorISE("payload expected in request context"), + err: acme.NewErrorISE("provisioner expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { @@ -408,10 +518,10 @@ func TestHandlerNewOrder(t *testing.T) { return test{ db: &acme.MockDB{ MockCreateOrder: func(ctx context.Context, o *acme.Order) error { + o.ID = "orderID" return nil }, }, - linker: NewLinker("dns", "acme"), ctx: ctx, statusCode: 201, } @@ -436,7 +546,6 @@ func TestHandlerNewOrder(t *testing.T) { return nil }, }, - linker: NewLinker("dns", "acme"), ctx: ctx, statusCode: 201, } @@ -445,7 +554,7 @@ func TestHandlerNewOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: tc.linker, db: tc.db} + h := &Handler{linker: NewLinker("dns", "prefix"), db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -481,26 +590,33 @@ func TestHandlerNewOrder(t *testing.T) { } func TestHandler_FinalizeOrder(t *testing.T) { - expiry := time.Now().UTC().Add(6 * time.Hour) - nbf := time.Now().UTC().Add(5 * time.Hour) - naf := nbf.Add(17 * time.Hour) + now := clock.Now() + nbf := now + naf := now.Add(24 * time.Hour) o := acme.Order{ ID: "orderID", - Expires: expiry, NotBefore: nbf, NotAfter: naf, Identifiers: []acme.Identifier{ - {Type: "dns", Value: "example.com"}, - {Type: "dns", Value: "bar.com"}, + { + Type: "dns", + Value: "example.com", + }, + { + Type: "dns", + Value: "*.smallstep.com", + }, }, - Status: "valid", - AuthorizationURLs: []string{"foo", "bar"}, - CertificateURL: "https://ca.smallstep.com/acme/certificate/certID", + Expires: naf, + Status: acme.StatusValid, + AuthorizationURLs: []string{ + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/foo", + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/bar", + "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/baz", + }, + FinalizeURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/orderID/finalize", + CertificateURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/certificate/certID", } - _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") - assert.FatalError(t, err) - csr, ok := _csr.(*x509.CertificateRequest) - assert.Fatal(t, ok) // Request with chi context chiCtx := chi.NewRouteContext() @@ -508,12 +624,22 @@ func TestHandler_FinalizeOrder(t *testing.T) { prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - url := fmt.Sprintf("%s/acme/%s/order/%s/finalize", + url := fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, o.ID) + _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") + assert.FatalError(t, err) + csr, ok := _csr.(*x509.CertificateRequest) + assert.Fatal(t, ok) + + nor := &FinalizeRequest{ + CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), + } + payloadBytes, err := json.Marshal(nor) + assert.FatalError(t, err) + type test struct { db acme.DB - linker Linker ctx context.Context statusCode int err *acme.Error @@ -521,7 +647,6 @@ func TestHandler_FinalizeOrder(t *testing.T) { var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - db: &acme.MockDB{}, ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), @@ -531,31 +656,49 @@ func TestHandler_FinalizeOrder(t *testing.T) { ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, nil) return test{ - db: &acme.MockDB{}, ctx: ctx, statusCode: 400, err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, - "fail/no-payload": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + "fail/no-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx = context.WithValue(ctx, accContextKey, acc) return test{ ctx: ctx, statusCode: 500, - err: acme.NewErrorISE("payload expected in request context"), + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/no-payload": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("payload does not exist"), } }, "fail/nil-payload": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} + acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - err: acme.NewErrorISE("payload expected in request context"), + err: acme.NewErrorISE("paylod does not exist"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { @@ -583,62 +726,112 @@ func TestHandler_FinalizeOrder(t *testing.T) { err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), } }, - "fail/FinalizeOrder-error": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - nor := &FinalizeRequest{ - CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), + "fail/db.GetOrder-error": func(t *testing.T) test { + + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), } - b, err := json.Marshal(nor) - assert.FatalError(t, err) + }, + "fail/account-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ db: &acme.MockDB{ - MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { - /* - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) - assert.Equals(t, id, o.ID) - assert.Equals(t, incsr.Raw, csr.Raw) - */ - return acme.NewError(acme.ErrorMalformedType, "force") + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{AccountID: "foo"}, nil }, }, ctx: ctx, - statusCode: 400, - err: acme.NewError(acme.ErrorMalformedType, "force"), + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), } }, - "ok": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - nor := &FinalizeRequest{ - CSR: base64.RawURLEncoding.EncodeToString(csr.Raw), + "fail/provisioner-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{AccountID: "accountID", ProvisionerID: "bar"}, nil + }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "provisioner id mismatch"), } - b, err := json.Marshal(nor) - assert.FatalError(t, err) + }, + "fail/order-finalize-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - linker: NewLinker("dns", "acme"), db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{ + AccountID: "accountID", + ProvisionerID: "acme/test@acme-provisioner.com", + Expires: clock.Now().Add(-time.Hour), + Status: acme.StatusReady, + }, nil + }, MockUpdateOrder: func(ctx context.Context, o *acme.Order) error { - /* - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) - assert.Equals(t, id, o.ID) - assert.Equals(t, incsr.Raw, csr.Raw) - return &o, nil - */ - return nil + return acme.NewErrorISE("force") + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), + } + }, + "ok": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { + return &acme.Order{ + ID: "orderID", + AccountID: "accountID", + ProvisionerID: "acme/test@acme-provisioner.com", + Expires: naf, + Status: acme.StatusValid, + AuthorizationIDs: []string{"foo", "bar", "baz"}, + NotBefore: nbf, + NotAfter: naf, + Identifiers: []acme.Identifier{ + { + Type: "dns", + Value: "example.com", + }, + { + Type: "dns", + Value: "*.smallstep.com", + }, + }, + CertificateID: "certID", + }, nil }, }, ctx: ctx, @@ -649,7 +842,7 @@ func TestHandler_FinalizeOrder(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{linker: tc.linker, db: tc.db} + h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -674,10 +867,12 @@ func TestHandler_FinalizeOrder(t *testing.T) { } else { expB, err := json.Marshal(o) assert.FatalError(t, err) + + ro := new(acme.Order) + err = json.Unmarshal(body, ro) + assert.Equals(t, bytes.TrimSpace(body), expB) - assert.Equals(t, res.Header["Location"], - []string{fmt.Sprintf("%s/acme/%s/order/%s", - baseURL, provName, o.ID)}) + assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) diff --git a/acme/order.go b/acme/order.go index 7b0b2d4d..a2c89fe7 100644 --- a/acme/order.go +++ b/acme/order.go @@ -20,6 +20,7 @@ type Identifier struct { // Order contains order metadata for the ACME protocol order type. type Order struct { + ID string `json:"id"` Status Status `json:"status"` Expires time.Time `json:"expires,omitempty"` Identifiers []Identifier `json:"identifiers"` @@ -31,7 +32,6 @@ type Order struct { FinalizeURL string `json:"finalize"` CertificateID string `json:"-"` CertificateURL string `json:"certificate,omitempty"` - ID string `json:"-"` AccountID string `json:"-"` ProvisionerID string `json:"-"` DefaultDuration time.Duration `json:"-"` @@ -50,7 +50,7 @@ func (o *Order) ToLog() (interface{}, error) { // UpdateStatus updates the ACME Order Status if necessary. // Changes to the order are saved using the database interface. func (o *Order) UpdateStatus(ctx context.Context, db DB) error { - now := time.Now().UTC() + now := clock.Now() switch o.Status { case StatusInvalid: