From c0a9f247989504566583d1391a19cdfae76ae07f Mon Sep 17 00:00:00 2001 From: max furman Date: Wed, 24 Mar 2021 16:50:35 -0700 Subject: [PATCH] add authorization and order unit tests --- acme/account_test.go | 759 ++-------------------------------- acme/authorization.go | 1 + acme/authorization_test.go | 150 +++++++ acme/authz_test.go | 824 ------------------------------------- acme/order.go | 9 +- acme/order_test.go | 256 ++++++++++++ 6 files changed, 451 insertions(+), 1548 deletions(-) create mode 100644 acme/authorization_test.go delete mode 100644 acme/authz_test.go diff --git a/acme/account_test.go b/acme/account_test.go index 45b86f20..5625c3dc 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -1,764 +1,81 @@ package acme import ( - "fmt" - "time" + "crypto" + "encoding/base64" + "testing" - "github.com/smallstep/certificates/authority/provisioner" + "github.com/pkg/errors" + "github.com/smallstep/assert" + "go.step.sm/crypto/jose" ) -var ( - defaultDisableRenewal = false - globalProvisionerClaims = provisioner.Claims{ - MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, - MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DisableRenewal: &defaultDisableRenewal, - } -) - -func newProv() Provisioner { - // Initialize provisioners - p := &provisioner.ACME{ - Type: "ACME", - Name: "test@acme-provisioner.com", - } - if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { - fmt.Printf("%v", err) - } - return p -} - -/* -func newAcc() (*Account, error) { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - if err != nil { - return nil, err - } - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - } - return newAccount(mockdb, AccountOptions{ - Key: jwk, Contact: []string{"foo", "bar"}, - }) -} -*/ - -/* -func TestGetAccountByID(t *testing.T) { - type test struct { - id string - db nosql.DB - acc *Account - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - id: acc.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: NewError(ErrorMalformedType, "account %s not found: not found", acc.ID), - } - }, - "fail/db-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - id: acc.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: NewErrorISE("error loading account %s: force", acc.ID), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - id: acc.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - return nil, nil - }, - }, - err: ServerInternalErr(errors.New("error unmarshaling account: unexpected end of JSON input")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - return test{ - acc: acc, - id: acc.ID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - return b, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acc, err := getAccountByID(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.acc.ID, acc.ID) - assert.Equals(t, tc.acc.Status, acc.Status) - assert.Equals(t, tc.acc.Created, acc.Created) - assert.Equals(t, tc.acc.Deactivated, acc.Deactivated) - assert.Equals(t, tc.acc.Contact, acc.Contact) - assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID) - } - } - }) - } -} - -func TestGetAccountByKeyID(t *testing.T) { - type test struct { - kid string - db nosql.DB - acc *account - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/kid-not-found": func(t *testing.T) test { - return test{ - kid: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("account with key id foo not found: not found")), - } - }, - "fail/db-error": func(t *testing.T) test { - return test{ - kid: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading key-account index: force")), - } - }, - "fail/getAccount-error": func(t *testing.T) test { - count := 0 - return test{ - kid: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 0 { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte("foo")) - count++ - return []byte("bar"), nil - } - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading account bar: force")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - count := 0 - return test{ - kid: acc.Key.KeyID, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - var ret []byte - switch count { - case 0: - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(acc.Key.KeyID)) - ret = []byte(acc.ID) - case 1: - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - ret = b - } - count++ - return ret, nil - }, - }, - acc: acc, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if acc, err := getAccountByKeyID(tc.db, tc.kid); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.acc.ID, acc.ID) - assert.Equals(t, tc.acc.Status, acc.Status) - assert.Equals(t, tc.acc.Created, acc.Created) - assert.Equals(t, tc.acc.Deactivated, acc.Deactivated) - assert.Equals(t, tc.acc.Contact, acc.Contact) - assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID) - } - } - }) - } -} - -func TestAccountToACME(t *testing.T) { - dir := newDirectory("ca.smallstep.com", "acme") - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - - type test struct { - acc *account - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{acc: acc} - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - acmeAccount, err := tc.acc.toACME(ctx, nil, dir) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acmeAccount.ID, tc.acc.ID) - assert.Equals(t, acmeAccount.Status, tc.acc.Status) - assert.Equals(t, acmeAccount.Contact, tc.acc.Contact) - assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID) - assert.Equals(t, acmeAccount.Orders, - fmt.Sprintf("%s/acme/%s/account/%s/orders", baseURL.String(), provName, tc.acc.ID)) - } - } - }) - } -} - -func TestAccountSave(t *testing.T) { - type test struct { - acc, old *account - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - "fail/old-nil/swap-false": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - return test{ - acc: acc, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil - }, - }, - err: ServerInternalErr(errors.New("error storing account; value has changed since last read")), - } - }, - "ok/old-nil": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - return test{ - acc: acc, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, accountTable) - assert.Equals(t, []byte(acc.ID), key) - return nil, true, nil - }, - }, - } - }, - "ok/old-not-nil": func(t *testing.T) test { - oldAcc, err := newAcc() - assert.FatalError(t, err) - acc, err := newAcc() - assert.FatalError(t, err) - - oldb, err := json.Marshal(oldAcc) - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - return test{ - acc: acc, - old: oldAcc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, newval, b) - assert.Equals(t, bucket, accountTable) - assert.Equals(t, []byte(acc.ID), key) - return []byte("foo"), true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.acc.save(tc.db, tc.old); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestAccountSaveNew(t *testing.T) { +func TestKeyToID(t *testing.T) { type test struct { - acc *account - db nosql.DB + jwk *jose.JSONWebKey + exp string err *Error } tests := map[string]func(t *testing.T) test{ - "fail/keyToID-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - acc.Key.Key = "foo" - return test{ - acc: acc, - err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")), - } - }, - "fail/swap-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, []byte(acc.ID)) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")), - } - }, - "fail/swap-false": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, []byte(acc.ID)) - return nil, false, nil - }, - }, - err: ServerInternalErr(errors.New("key-id to account-id index already exists")), - } - }, - "fail/save-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - count := 0 - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, []byte(acc.ID)) - count++ - return nil, true, nil - } - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, b) - return nil, false, errors.New("force") - }, - MDel: func(bucket, key []byte) error { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - return nil - }, - }, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - kid, err := keyToID(acc.Key) - assert.FatalError(t, err) - b, err := json.Marshal(acc) - assert.FatalError(t, err) - count := 0 - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, []byte(acc.ID)) - count++ - return nil, true, nil - } - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, nil) - assert.Equals(t, newval, b) - return nil, true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.acc.saveNew(tc.db); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestAccountUpdate(t *testing.T) { - type test struct { - acc *account - contact []string - db nosql.DB - res []byte - err *Error - } - contact := []string{"foo", "bar"} - tests := map[string]func(t *testing.T) test{ - "fail/save-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - oldb, err := json.Marshal(acc) - assert.FatalError(t, err) - - _acc := *acc - clone := &_acc - clone.Contact = contact - b, err := json.Marshal(clone) + "fail/error-generating-thumbprint": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) + jwk.Key = "foo" return test{ - acc: acc, - contact: contact, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, b) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing account: force")), + jwk: jwk, + err: NewErrorISE("error generating jwk thumbprint: square/go-jose: unknown key type 'string'"), } }, "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - oldb, err := json.Marshal(acc) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - _acc := *acc - clone := &_acc - clone.Contact = contact - b, err := json.Marshal(clone) - assert.FatalError(t, err) - return test{ - acc: acc, - contact: contact, - res: b, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, b) - return nil, true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - acc, err := tc.acc.update(tc.db, tc.contact) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - b, err := json.Marshal(acc) - assert.FatalError(t, err) - assert.Equals(t, b, tc.res) - } - } - }) - } -} - -func TestAccountDeactivate(t *testing.T) { - type test struct { - acc *account - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/save-error": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - oldb, err := json.Marshal(acc) - assert.FatalError(t, err) - - return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, oldb) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing account: force")), - } - }, - "ok": func(t *testing.T) test { - acc, err := newAcc() - assert.FatalError(t, err) - oldb, err := json.Marshal(acc) + kid, err := jwk.Thumbprint(crypto.SHA256) assert.FatalError(t, err) return test{ - acc: acc, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, key, []byte(acc.ID)) - assert.Equals(t, old, oldb) - return nil, true, nil - }, - }, + jwk: jwk, + exp: base64.RawURLEncoding.EncodeToString(kid), } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - acc, err := tc.acc.deactivate(tc.db) - if err != nil { + if id, err := KeyToID(tc.jwk); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, acc.ID, tc.acc.ID) - assert.Equals(t, acc.Contact, tc.acc.Contact) - assert.Equals(t, acc.Status, StatusDeactivated) - assert.Equals(t, acc.Key.KeyID, tc.acc.Key.KeyID) - assert.Equals(t, acc.Created, tc.acc.Created) - - assert.True(t, acc.Deactivated.Before(time.Now().Add(time.Minute))) - assert.True(t, acc.Deactivated.After(time.Now().Add(-time.Minute))) + assert.Equals(t, id, tc.exp) } } }) } } -func TestNewAccount(t *testing.T) { - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - kid, err := keyToID(jwk) - assert.FatalError(t, err) - ops := AccountOptions{ - Key: jwk, - Contact: []string{"foo", "bar"}, - } +func TestAccount_IsValid(t *testing.T) { type test struct { - ops AccountOptions - db nosql.DB - err *Error - id *string + acc *Account + exp bool } - tests := map[string]func(t *testing.T) test{ - "fail/store-error": func(t *testing.T) test { - return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")), - } - }, - "ok": func(t *testing.T) test { - var _id string - id := &_id - count := 0 - return test{ - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - switch count { - case 0: - assert.Equals(t, bucket, accountByKeyIDTable) - assert.Equals(t, key, []byte(kid)) - case 1: - assert.Equals(t, bucket, accountTable) - *id = string(key) - } - count++ - return nil, true, nil - }, - }, - id: id, - } - }, + tests := map[string]test{ + "valid": {acc: &Account{Status: StatusValid}, exp: true}, + "invalid": {acc: &Account{Status: StatusInvalid}, exp: false}, } - for name, run := range tests { - tc := run(t) + for name, tc := range tests { t.Run(name, func(t *testing.T) { - acc, err := newAccount(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acc.ID, *tc.id) - assert.Equals(t, acc.Status, StatusValid) - assert.Equals(t, acc.Contact, ops.Contact) - assert.Equals(t, acc.Key.KeyID, ops.Key.KeyID) - - assert.True(t, acc.Deactivated.IsZero()) - - assert.True(t, acc.Created.Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, acc.Created.After(time.Now().UTC().Add(-1*time.Minute))) - } - } + assert.Equals(t, tc.acc.IsValid(), tc.exp) }) } } -*/ diff --git a/acme/authorization.go b/acme/authorization.go index 62bc4637..4d5c42c8 100644 --- a/acme/authorization.go +++ b/acme/authorization.go @@ -57,6 +57,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { return nil } az.Status = StatusValid + az.Error = nil default: return NewErrorISE("unrecognized authorization status: %s", az.Status) } diff --git a/acme/authorization_test.go b/acme/authorization_test.go new file mode 100644 index 00000000..00b35b99 --- /dev/null +++ b/acme/authorization_test.go @@ -0,0 +1,150 @@ +package acme + +import ( + "context" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" +) + +func TestAuthorization_UpdateStatus(t *testing.T) { + type test struct { + az *Authorization + err *Error + db DB + } + tests := map[string]func(t *testing.T) test{ + "ok/already-invalid": func(t *testing.T) test { + az := &Authorization{ + Status: StatusInvalid, + } + return test{ + az: az, + } + }, + "ok/already-valid": func(t *testing.T) test { + az := &Authorization{ + Status: StatusInvalid, + } + return test{ + az: az, + } + }, + "fail/error-unexpected-status": func(t *testing.T) test { + az := &Authorization{ + Status: "foo", + } + return test{ + az: az, + err: NewErrorISE("unrecognized authorization status: %s", az.Status), + } + }, + "ok/expired": func(t *testing.T) test { + now := clock.Now() + az := &Authorization{ + ID: "azID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(-5 * time.Minute), + } + return test{ + az: az, + db: &MockDB{ + MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error { + assert.Equals(t, updaz.ID, az.ID) + assert.Equals(t, updaz.AccountID, az.AccountID) + assert.Equals(t, updaz.Status, StatusInvalid) + assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt) + return nil + }, + }, + } + }, + "fail/db.UpdateAuthorization-error": func(t *testing.T) test { + now := clock.Now() + az := &Authorization{ + ID: "azID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(-5 * time.Minute), + } + return test{ + az: az, + db: &MockDB{ + MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error { + assert.Equals(t, updaz.ID, az.ID) + assert.Equals(t, updaz.AccountID, az.AccountID) + assert.Equals(t, updaz.Status, StatusInvalid) + assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt) + return errors.New("force") + }, + }, + err: NewErrorISE("error updating authorization: force"), + } + }, + "ok/no-valid-challenges": func(t *testing.T) test { + now := clock.Now() + az := &Authorization{ + ID: "azID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + Challenges: []*Challenge{ + {Status: StatusPending}, {Status: StatusPending}, {Status: StatusPending}, + }, + } + return test{ + az: az, + } + }, + "ok/valid": func(t *testing.T) test { + now := clock.Now() + az := &Authorization{ + ID: "azID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + Challenges: []*Challenge{ + {Status: StatusPending}, {Status: StatusPending}, {Status: StatusValid}, + }, + } + return test{ + az: az, + db: &MockDB{ + MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error { + assert.Equals(t, updaz.ID, az.ID) + assert.Equals(t, updaz.AccountID, az.AccountID) + assert.Equals(t, updaz.Status, StatusValid) + assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt) + assert.Equals(t, updaz.Error, nil) + return nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.az.UpdateStatus(context.Background(), tc.db); err != nil { + if assert.NotNil(t, tc.err) { + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } + } + } else { + assert.Nil(t, tc.err) + } + }) + + } +} diff --git a/acme/authz_test.go b/acme/authz_test.go deleted file mode 100644 index 206921c6..00000000 --- a/acme/authz_test.go +++ /dev/null @@ -1,824 +0,0 @@ -package acme - -/* -func newAz() (*Authorization, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newAuthz(mockdb, "1234", Identifier{ - Type: "dns", Value: "acme.example.com", - }) -} - -func TestGetAuthz(t *testing.T) { - type test struct { - id string - db nosql.DB - az authz - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - return test{ - az: az, - id: az.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())), - } - }, - "fail/db-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - return test{ - az: az, - id: az.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Identifier.Type = "foo" - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - id: az.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(az.getID())) - return b, nil - }, - }, - err: ServerInternalErr(errors.New("unexpected authz type foo")), - } - }, - "ok": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - id: az.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(az.getID())) - return b, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if az, err := getAuthz(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.az.getID(), az.getID()) - assert.Equals(t, tc.az.getAccountID(), az.getAccountID()) - assert.Equals(t, tc.az.getStatus(), az.getStatus()) - assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier()) - assert.Equals(t, tc.az.getCreated(), az.getCreated()) - assert.Equals(t, tc.az.getExpiry(), az.getExpiry()) - assert.Equals(t, tc.az.getChallenges(), az.getChallenges()) - } - } - }) - } -} - -func TestAuthzClone(t *testing.T) { - az, err := newAz() - assert.FatalError(t, err) - - clone := az.clone() - - assert.Equals(t, clone.getID(), az.getID()) - assert.Equals(t, clone.getAccountID(), az.getAccountID()) - assert.Equals(t, clone.getStatus(), az.getStatus()) - assert.Equals(t, clone.getIdentifier(), az.getIdentifier()) - assert.Equals(t, clone.getExpiry(), az.getExpiry()) - assert.Equals(t, clone.getCreated(), az.getCreated()) - assert.Equals(t, clone.getChallenges(), az.getChallenges()) - - clone.Status = StatusValid - - assert.NotEquals(t, clone.getStatus(), az.getStatus()) -} - -func TestNewAuthz(t *testing.T) { - iden := Identifier{ - Type: "dns", Value: "acme.example.com", - } - accID := "1234" - type test struct { - iden Identifier - db nosql.DB - err *Error - resChs *([]string) - } - tests := map[string]func(t *testing.T) test{ - "fail/unexpected-type": func(t *testing.T) test { - return test{ - iden: Identifier{Type: "foo", Value: "acme.example.com"}, - err: MalformedErr(errors.New("unexpected authz type foo")), - } - }, - "fail/new-http-chall-error": func(t *testing.T) test { - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")), - } - }, - "fail/new-tls-alpn-chall-error": func(t *testing.T) test { - count := 0 - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 1 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - }, - err: ServerInternalErr(errors.New("error creating alpn challenge: error saving acme challenge: force")), - } - }, - "fail/new-dns-chall-error": func(t *testing.T) test { - count := 0 - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 2 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - }, - err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")), - } - }, - "fail/save-authz-error": func(t *testing.T) test { - count := 0 - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 3 { - return nil, false, errors.New("force") - } - count++ - return nil, true, nil - }, - }, - err: ServerInternalErr(errors.New("error storing authz: force")), - } - }, - "ok": func(t *testing.T) test { - chs := &([]string{}) - count := 0 - return test{ - iden: iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 3 { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, old, nil) - - az, err := unmarshalAuthz(newval) - assert.FatalError(t, err) - - assert.Equals(t, az.getID(), string(key)) - assert.Equals(t, az.getAccountID(), accID) - assert.Equals(t, az.getStatus(), StatusPending) - assert.Equals(t, az.getIdentifier(), iden) - assert.Equals(t, az.getWildcard(), false) - - *chs = az.getChallenges() - - assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - expiry := az.getCreated().Add(defaultExpiryDuration) - assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute))) - assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute))) - } - count++ - return nil, true, nil - }, - }, - resChs: chs, - } - }, - "ok/wildcard": func(t *testing.T) test { - chs := &([]string{}) - count := 0 - _iden := Identifier{Type: "dns", Value: "*.acme.example.com"} - return test{ - iden: _iden, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 1 { - assert.Equals(t, bucket, authzTable) - assert.Equals(t, old, nil) - - az, err := unmarshalAuthz(newval) - assert.FatalError(t, err) - - assert.Equals(t, az.getID(), string(key)) - assert.Equals(t, az.getAccountID(), accID) - assert.Equals(t, az.getStatus(), StatusPending) - assert.Equals(t, az.getIdentifier(), iden) - assert.Equals(t, az.getWildcard(), true) - - *chs = az.getChallenges() - // Verify that we only have 1 challenge instead of 2. - assert.True(t, len(*chs) == 1) - - assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - expiry := az.getCreated().Add(defaultExpiryDuration) - assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute))) - assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute))) - } - count++ - return nil, true, nil - }, - }, - resChs: chs, - } - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - az, err := newAuthz(tc.db, accID, tc.iden) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, az.getAccountID(), accID) - assert.Equals(t, az.getType(), "dns") - assert.Equals(t, az.getStatus(), StatusPending) - - assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - expiry := az.getCreated().Add(defaultExpiryDuration) - assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute))) - assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute))) - - assert.Equals(t, az.getChallenges(), *(tc.resChs)) - - if strings.HasPrefix(tc.iden.Value, "*.") { - assert.True(t, az.getWildcard()) - assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*.")) - } else { - assert.False(t, az.getWildcard()) - assert.Equals(t, az.getIdentifier().Value, tc.iden.Value) - } - - assert.True(t, az.getID() != "") - } - } - }) - } -} - -func TestAuthzToACME(t *testing.T) { - dir := newDirectory("ca.smallstep.com", "acme") - - var ( - ch1, ch2 challenge - ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{}) - err error - ) - - count := 0 - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - *ch1Bytes = newval - ch1, err = unmarshalChallenge(newval) - assert.FatalError(t, err) - } else if count == 1 { - *ch2Bytes = newval - ch2, err = unmarshalChallenge(newval) - assert.FatalError(t, err) - } - count++ - return []byte("foo"), true, nil - }, - } - iden := Identifier{ - Type: "dns", Value: "acme.example.com", - } - az, err := newAuthz(mockdb, "1234", iden) - assert.FatalError(t, err) - - prov := newProv() - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") - - type test struct { - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/getChallenge1-error": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading challenge")), - } - }, - "fail/getChallenge2-error": func(t *testing.T) test { - count := 0 - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 1 { - return nil, errors.New("force") - } - count++ - return *ch1Bytes, nil - }, - }, - err: ServerInternalErr(errors.New("error loading challenge")), - } - }, - "ok": func(t *testing.T) test { - count := 0 - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 0 { - count++ - return *ch1Bytes, nil - } - return *ch2Bytes, nil - }, - }, - } - }, - } - for name, run := range tests { - tc := run(t) - t.Run(name, func(t *testing.T) { - acmeAz, err := az.toACME(ctx, tc.db, dir) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, acmeAz.ID, az.getID()) - assert.Equals(t, acmeAz.Identifier, iden) - assert.Equals(t, acmeAz.Status, StatusPending) - - acmeCh1, err := ch1.toACME(ctx, nil, dir) - assert.FatalError(t, err) - acmeCh2, err := ch2.toACME(ctx, nil, dir) - assert.FatalError(t, err) - - assert.Equals(t, acmeAz.Challenges[0], acmeCh1) - assert.Equals(t, acmeAz.Challenges[1], acmeCh2) - - expiry, err := time.Parse(time.RFC3339, acmeAz.Expires) - assert.FatalError(t, err) - assert.Equals(t, expiry.String(), az.getExpiry().String()) - } - } - }) - } -} - -func TestAuthzSave(t *testing.T) { - type test struct { - az, old authz - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - return test{ - az: az, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing authz: force")), - } - }, - "fail/old-nil/swap-false": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - return test{ - az: az, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil - }, - }, - err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")), - } - }, - "ok/old-nil": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, authzTable) - assert.Equals(t, []byte(az.getID()), key) - return nil, true, nil - }, - }, - } - }, - "ok/old-not-nil": func(t *testing.T) test { - oldAz, err := newAz() - assert.FatalError(t, err) - az, err := newAz() - assert.FatalError(t, err) - - oldb, err := json.Marshal(oldAz) - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - old: oldAz, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, authzTable) - assert.Equals(t, []byte(az.getID()), key) - return []byte("foo"), true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.az.save(tc.db, tc.old); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestAuthzUnmarshal(t *testing.T) { - type test struct { - az authz - azb []byte - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/nil": func(t *testing.T) test { - return test{ - azb: nil, - err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")), - } - }, - "fail/unexpected-type": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Identifier.Type = "foo" - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - azb: b, - err: ServerInternalErr(errors.New("unexpected authz type foo")), - } - }, - "ok/dns": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - return test{ - az: az, - azb: b, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if az, err := unmarshalAuthz(tc.azb); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.az.getID(), az.getID()) - assert.Equals(t, tc.az.getAccountID(), az.getAccountID()) - assert.Equals(t, tc.az.getStatus(), az.getStatus()) - assert.Equals(t, tc.az.getCreated(), az.getCreated()) - assert.Equals(t, tc.az.getExpiry(), az.getExpiry()) - assert.Equals(t, tc.az.getWildcard(), az.getWildcard()) - assert.Equals(t, tc.az.getChallenges(), az.getChallenges()) - } - } - }) - } -} - -func TestAuthzUpdateStatus(t *testing.T) { - type test struct { - az, res authz - err *Error - db nosql.DB - } - tests := map[string]func(t *testing.T) test{ - "fail/already-invalid": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Status = StatusInvalid - return test{ - az: az, - res: az, - } - }, - "fail/already-valid": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Status = StatusValid - return test{ - az: az, - res: az, - } - }, - "fail/unexpected-status": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Status = StatusReady - return test{ - az: az, - res: az, - err: ServerInternalErr(errors.New("unrecognized authz status: ready")), - } - }, - "fail/save-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute) - return test{ - az: az, - res: az, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing authz: force")), - } - }, - "ok/expired": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute) - - clone := az.clone() - clone.Error = MalformedErr(errors.New("authz has expired")) - clone.Status = StatusInvalid - return test{ - az: az, - res: clone.parent(), - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - }, - } - }, - "fail/get-challenge-error": func(t *testing.T) test { - az, err := newAz() - assert.FatalError(t, err) - - return test{ - az: az, - res: az, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading challenge")), - } - }, - "ok/valid": func(t *testing.T) test { - var ( - ch3 challenge - ch2Bytes = &([]byte{}) - ch1Bytes = &([]byte{}) - err error - ) - - count := 0 - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - *ch1Bytes = newval - } else if count == 1 { - *ch2Bytes = newval - } else if count == 2 { - ch3, err = unmarshalChallenge(newval) - assert.FatalError(t, err) - } - count++ - return nil, true, nil - }, - } - iden := Identifier{ - Type: "dns", Value: "acme.example.com", - } - az, err := newAuthz(mockdb, "1234", iden) - assert.FatalError(t, err) - _az, ok := az.(*dnsAuthz) - assert.Fatal(t, ok) - _az.baseAuthz.Error = MalformedErr(nil) - - _ch, ok := ch3.(*dns01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid - chb, err := json.Marshal(ch3) - - clone := az.clone() - clone.Status = StatusValid - clone.Error = nil - - count = 0 - return test{ - az: az, - res: clone.parent(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 0 { - count++ - return *ch1Bytes, nil - } - if count == 1 { - count++ - return *ch2Bytes, nil - } - count++ - return chb, nil - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, true, nil - }, - }, - } - }, - "ok/still-pending": func(t *testing.T) test { - var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{}) - - count := 0 - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - if count == 0 { - *ch1Bytes = newval - } else if count == 1 { - *ch2Bytes = newval - } - count++ - return nil, true, nil - }, - } - iden := Identifier{ - Type: "dns", Value: "acme.example.com", - } - az, err := newAuthz(mockdb, "1234", iden) - assert.FatalError(t, err) - - count = 0 - return test{ - az: az, - res: az, - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - if count == 0 { - count++ - return *ch1Bytes, nil - } - count++ - return *ch2Bytes, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - az, err := tc.az.updateStatus(tc.db) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - expB, err := json.Marshal(tc.res) - assert.FatalError(t, err) - b, err := json.Marshal(az) - assert.FatalError(t, err) - assert.Equals(t, expB, b) - } - } - }) - } -} -*/ diff --git a/acme/order.go b/acme/order.go index f62e3354..400a4ce2 100644 --- a/acme/order.go +++ b/acme/order.go @@ -81,10 +81,10 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { for _, azID := range o.AuthorizationIDs { az, err := db.GetAuthorization(ctx, azID) if err != nil { - return err + return WrapErrorISE(err, "error getting authorization ID %s", azID) } if err = az.UpdateStatus(ctx, db); err != nil { - return err + return WrapErrorISE(err, "error updating authorization ID %s", azID) } st := az.Status count[st]++ @@ -107,7 +107,10 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error { default: return NewErrorISE("unrecognized order status: %s", o.Status) } - return db.UpdateOrder(ctx, o) + if err := db.UpdateOrder(ctx, o); err != nil { + return WrapErrorISE(err, "error updating order") + } + return nil } // Finalize signs a certificate if the necessary conditions for Order completion diff --git a/acme/order_test.go b/acme/order_test.go index 5bd21fdb..d86afeb5 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -1,5 +1,261 @@ package acme +import ( + "context" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" +) + +func TestOrder_UpdateStatus(t *testing.T) { + type test struct { + o *Order + err *Error + db DB + } + tests := map[string]func(t *testing.T) test{ + "ok/already-invalid": func(t *testing.T) test { + o := &Order{ + Status: StatusInvalid, + } + return test{ + o: o, + } + }, + "ok/already-valid": func(t *testing.T) test { + o := &Order{ + Status: StatusInvalid, + } + return test{ + o: o, + } + }, + "fail/error-unexpected-status": func(t *testing.T) test { + o := &Order{ + Status: "foo", + } + return test{ + o: o, + err: NewErrorISE("unrecognized order status: %s", o.Status), + } + }, + "ok/ready-expired": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(-5 * time.Minute), + } + return test{ + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusInvalid) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + return nil + }, + }, + } + }, + "fail/ready-expired-db.UpdateOrder-error": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusReady, + ExpiresAt: now.Add(-5 * time.Minute), + } + return test{ + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusInvalid) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + return errors.New("force") + }, + }, + err: NewErrorISE("error updating order: force"), + } + }, + "ok/pending-expired": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(-5 * time.Minute), + } + return test{ + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusInvalid) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + + err := NewError(ErrorMalformedType, "order has expired") + assert.HasPrefix(t, updo.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updo.Error.Type, err.Type) + assert.Equals(t, updo.Error.Detail, err.Detail) + assert.Equals(t, updo.Error.Status, err.Status) + assert.Equals(t, updo.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + "ok/invalid": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + } + az1 := &Authorization{ + ID: "a", + Status: StatusValid, + } + az2 := &Authorization{ + ID: "b", + Status: StatusInvalid, + } + + return test{ + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusInvalid) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + return nil + }, + MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { + switch id { + case az1.ID: + return az1, nil + case az2.ID: + return az2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) + return nil, errors.New("force") + } + }, + }, + } + }, + "ok/still-pending": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + } + az1 := &Authorization{ + ID: "a", + Status: StatusValid, + } + az2 := &Authorization{ + ID: "b", + Status: StatusPending, + } + + return test{ + o: o, + db: &MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { + switch id { + case az1.ID: + return az1, nil + case az2.ID: + return az2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) + return nil, errors.New("force") + } + }, + }, + } + }, + "ok/valid": func(t *testing.T) test { + now := clock.Now() + o := &Order{ + ID: "oID", + AccountID: "accID", + Status: StatusPending, + ExpiresAt: now.Add(5 * time.Minute), + AuthorizationIDs: []string{"a", "b"}, + } + az1 := &Authorization{ + ID: "a", + Status: StatusValid, + } + az2 := &Authorization{ + ID: "b", + Status: StatusValid, + } + + return test{ + o: o, + db: &MockDB{ + MockUpdateOrder: func(ctx context.Context, updo *Order) error { + assert.Equals(t, updo.ID, o.ID) + assert.Equals(t, updo.AccountID, o.AccountID) + assert.Equals(t, updo.Status, StatusReady) + assert.Equals(t, updo.ExpiresAt, o.ExpiresAt) + return nil + }, + MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) { + switch id { + case az1.ID: + return az1, nil + case az2.ID: + return az2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected authz key %s", id)) + return nil, errors.New("force") + } + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.o.UpdateStatus(context.Background(), tc.db); err != nil { + if assert.NotNil(t, tc.err) { + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } + } + } else { + assert.Nil(t, tc.err) + } + }) + + } +} + /* var certDuration = 6 * time.Hour