diff --git a/acme/account_test.go b/acme/account_test.go index 0008551a..2e072af5 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -251,337 +251,6 @@ func TestGetAccountByKeyID(t *testing.T) { } } -func Test_getOrderIDsByAccount(t *testing.T) { - type test struct { - id string - db nosql.DB - res []string - err *Error - } - tests := map[string]func(t *testing.T) test{ - "ok/not-found": func(t *testing.T) test { - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - res: []string{}, - } - }, - "fail/db-error": func(t *testing.T) test { - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return nil, nil - }, - }, - err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")), - } - }, - "fail/error-loading-order-from-order-IDs": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - dbHit := 0 - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - dbHit++ - switch dbHit { - case 1: - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return boids, nil - case 2: - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte("o1")) - return nil, errors.New("force") - default: - assert.FatalError(t, errors.New("should not be here")) - return nil, nil - } - }, - }, - err: ServerInternalErr(errors.New("error loading order o1 for account foo: error loading order o1: force")), - } - }, - "fail/error-updating-order-from-order-IDs": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - dbHit := 0 - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - dbHit++ - switch dbHit { - case 1: - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return boids, nil - case 2: - assert.Equals(t, bucket, orderTable) - assert.Equals(t, key, []byte("o1")) - return bo, nil - case 3: - assert.Equals(t, bucket, authzTable) - assert.Equals(t, key, []byte(o.Authorizations[0])) - return nil, errors.New("force") - default: - assert.FatalError(t, errors.New("should not be here")) - return nil, nil - } - }, - }, - err: ServerInternalErr(errors.Errorf("error updating order o1 for account foo: error loading authz %s: force", o.Authorizations[0])), - } - }, - "ok/no-change-to-pending-orders": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) - - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - return bo, nil - case string(authzTable): - return baz, nil - case string(challengeTable): - return bch, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("should not be attempting to store anything") - }, - }, - res: oids, - } - }, - "fail/error-storing-new-oids": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - invalidOrder, err := newO() - assert.FatalError(t, err) - invalidOrder.Status = StatusInvalid - binvalidOrder, err := json.Marshal(invalidOrder) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) - - dbGetOrder := 0 - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - dbGetOrder++ - if dbGetOrder == 1 { - return binvalidOrder, nil - } - return bo, nil - case string(authzTable): - return baz, nil - case string(challengeTable): - return bch, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error storing orderIDs as part of getOrderIDsByAccount logic: len(orderIDs) = 2: error storing order IDs for account foo: force")), - } - }, - "ok": func(t *testing.T) test { - oids := []string{"o1", "o2", "o3", "o4"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - o, err := newO() - assert.FatalError(t, err) - bo, err := json.Marshal(o) - assert.FatalError(t, err) - - invalidOrder, err := newO() - assert.FatalError(t, err) - invalidOrder.Status = StatusInvalid - binvalidOrder, err := json.Marshal(invalidOrder) - assert.FatalError(t, err) - - az, err := newAz() - assert.FatalError(t, err) - baz, err := json.Marshal(az) - assert.FatalError(t, err) - - ch, err := newDNSCh() - assert.FatalError(t, err) - bch, err := json.Marshal(ch) - assert.FatalError(t, err) - - dbGetOrder := 0 - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - dbGetOrder++ - if dbGetOrder == 1 || dbGetOrder == 3 { - return binvalidOrder, nil - } - return bo, nil - case string(authzTable): - return baz, nil - case string(challengeTable): - return bch, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - return nil, true, nil - }, - }, - res: []string{"o2", "o4"}, - } - }, - "ok/no-pending-orders": func(t *testing.T) test { - oids := []string{"o1"} - boids, err := json.Marshal(oids) - assert.FatalError(t, err) - - invalidOrder, err := newO() - assert.FatalError(t, err) - invalidOrder.Status = StatusInvalid - binvalidOrder, err := json.Marshal(invalidOrder) - assert.FatalError(t, err) - - return test{ - id: "foo", - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(ordersByAccountIDTable): - assert.Equals(t, key, []byte("foo")) - return boids, nil - case string(orderTable): - return binvalidOrder, nil - default: - assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) - return nil, nil - } - }, - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, ordersByAccountIDTable) - assert.Equals(t, key, []byte("foo")) - assert.Equals(t, old, boids) - assert.Nil(t, newval) - return nil, true, nil - }, - }, - res: []string{}, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if oids, err := getOrderIDsByAccount(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res, oids) - } - } - }) - } -} - func TestAccountToACME(t *testing.T) { dir := newDirectory("ca.smallstep.com", "acme") prov := newProv() diff --git a/acme/authority.go b/acme/authority.go index d1bb0aaf..0f5f2c9f 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -233,8 +233,11 @@ func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order // GetOrdersByAccount returns the list of order urls owned by the account. func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) { + ordersByAccountMux.Lock() + defer ordersByAccountMux.Unlock() + var oiba = orderIDsByAccount{} - oids, err := oiba.getOrderIDsByAccount(a.db, id, false) + oids, err := oiba.unsafeGetOrderIDsByAccount(a.db, id) if err != nil { return nil, err } diff --git a/acme/order.go b/acme/order.go index ef5345e4..b0b0ec54 100644 --- a/acme/order.go +++ b/acme/order.go @@ -125,12 +125,15 @@ func newOrder(db nosql.DB, ops OrderOptions) (*order, error) { type orderIDsByAccount struct{} +// addOrderID adds an order ID to a users index of in progress order IDs. +// This method will also cull any orders that are no longer in the `pending` +// state from the index before returning it. func (oiba orderIDsByAccount) addOrderID(db nosql.DB, accID string, oid string) ([]string, error) { ordersByAccountMux.Lock() defer ordersByAccountMux.Unlock() // Update the "order IDs by account ID" index - oids, err := oiba.getOrderIDsByAccount(db, accID, true) + oids, err := oiba.unsafeGetOrderIDsByAccount(db, accID) if err != nil { return nil, err } @@ -143,15 +146,9 @@ func (oiba orderIDsByAccount) addOrderID(db nosql.DB, accID string, oid string) return newOids, nil } -// getOrderIDsByAccount retrieves a list of Order IDs that were created by the +// unsafeGetOrderIDsByAccount retrieves a list of Order IDs that were created by the // account. -func (oiba orderIDsByAccount) getOrderIDsByAccount(db nosql.DB, accID string, alreadyLocked bool) ([]string, error) { - if !alreadyLocked { - ordersByAccountMux.Lock() - - defer ordersByAccountMux.Unlock() - } - +func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID string) ([]string, error) { b, err := db.Get(ordersByAccountIDTable, []byte(accID)) if err != nil { if nosql.IsErrNotFound(err) { diff --git a/acme/order_test.go b/acme/order_test.go index 785b24c4..e6a8f057 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -1403,3 +1403,335 @@ func TestOrderFinalize(t *testing.T) { }) } } + +func Test_getOrderIDsByAccount(t *testing.T) { + type test struct { + id string + db nosql.DB + res []string + err *Error + } + tests := map[string]func(t *testing.T) test{ + "ok/not-found": func(t *testing.T) test { + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + }, + res: []string{}, + } + }, + "fail/db-error": func(t *testing.T) test { + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error loading orderIDs for account foo: force")), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return nil, nil + }, + }, + err: ServerInternalErr(errors.New("error unmarshaling orderIDs for account foo: unexpected end of JSON input")), + } + }, + "fail/error-loading-order-from-order-IDs": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + dbHit := 0 + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + dbHit++ + switch dbHit { + case 1: + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return boids, nil + case 2: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("o1")) + return nil, errors.New("force") + default: + assert.FatalError(t, errors.New("should not be here")) + return nil, nil + } + }, + }, + err: ServerInternalErr(errors.New("error loading order o1 for account foo: error loading order o1: force")), + } + }, + "fail/error-updating-order-from-order-IDs": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + bo, err := json.Marshal(o) + assert.FatalError(t, err) + + dbHit := 0 + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + dbHit++ + switch dbHit { + case 1: + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return boids, nil + case 2: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("o1")) + return bo, nil + case 3: + assert.Equals(t, bucket, authzTable) + assert.Equals(t, key, []byte(o.Authorizations[0])) + return nil, errors.New("force") + default: + assert.FatalError(t, errors.New("should not be here")) + return nil, nil + } + }, + }, + err: ServerInternalErr(errors.Errorf("error updating order o1 for account foo: error loading authz %s: force", o.Authorizations[0])), + } + }, + "ok/no-change-to-pending-orders": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + bo, err := json.Marshal(o) + assert.FatalError(t, err) + + az, err := newAz() + assert.FatalError(t, err) + baz, err := json.Marshal(az) + assert.FatalError(t, err) + + ch, err := newDNSCh() + assert.FatalError(t, err) + bch, err := json.Marshal(ch) + assert.FatalError(t, err) + + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte("foo")) + return boids, nil + case string(orderTable): + return bo, nil + case string(authzTable): + return baz, nil + case string(challengeTable): + return bch, nil + default: + assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) + return nil, nil + } + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("should not be attempting to store anything") + }, + }, + res: oids, + } + }, + "fail/error-storing-new-oids": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + bo, err := json.Marshal(o) + assert.FatalError(t, err) + + invalidOrder, err := newO() + assert.FatalError(t, err) + invalidOrder.Status = StatusInvalid + binvalidOrder, err := json.Marshal(invalidOrder) + assert.FatalError(t, err) + + az, err := newAz() + assert.FatalError(t, err) + baz, err := json.Marshal(az) + assert.FatalError(t, err) + + ch, err := newDNSCh() + assert.FatalError(t, err) + bch, err := json.Marshal(ch) + assert.FatalError(t, err) + + dbGetOrder := 0 + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte("foo")) + return boids, nil + case string(orderTable): + dbGetOrder++ + if dbGetOrder == 1 { + return binvalidOrder, nil + } + return bo, nil + case string(authzTable): + return baz, nil + case string(challengeTable): + return bch, nil + default: + assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) + return nil, nil + } + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error storing orderIDs as part of getOrderIDsByAccount logic: len(orderIDs) = 2: error storing order IDs for account foo: force")), + } + }, + "ok": func(t *testing.T) test { + oids := []string{"o1", "o2", "o3", "o4"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + o, err := newO() + assert.FatalError(t, err) + bo, err := json.Marshal(o) + assert.FatalError(t, err) + + invalidOrder, err := newO() + assert.FatalError(t, err) + invalidOrder.Status = StatusInvalid + binvalidOrder, err := json.Marshal(invalidOrder) + assert.FatalError(t, err) + + az, err := newAz() + assert.FatalError(t, err) + baz, err := json.Marshal(az) + assert.FatalError(t, err) + + ch, err := newDNSCh() + assert.FatalError(t, err) + bch, err := json.Marshal(ch) + assert.FatalError(t, err) + + dbGetOrder := 0 + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte("foo")) + return boids, nil + case string(orderTable): + dbGetOrder++ + if dbGetOrder == 1 || dbGetOrder == 3 { + return binvalidOrder, nil + } + return bo, nil + case string(authzTable): + return baz, nil + case string(challengeTable): + return bch, nil + default: + assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) + return nil, nil + } + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + return nil, true, nil + }, + }, + res: []string{"o2", "o4"}, + } + }, + "ok/no-pending-orders": func(t *testing.T) test { + oids := []string{"o1"} + boids, err := json.Marshal(oids) + assert.FatalError(t, err) + + invalidOrder, err := newO() + assert.FatalError(t, err) + invalidOrder.Status = StatusInvalid + binvalidOrder, err := json.Marshal(invalidOrder) + assert.FatalError(t, err) + + return test{ + id: "foo", + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte("foo")) + return boids, nil + case string(orderTable): + return binvalidOrder, nil + default: + assert.FatalError(t, errors.Errorf("did not expect query to table %s", bucket)) + return nil, nil + } + }, + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte("foo")) + assert.Equals(t, old, boids) + assert.Nil(t, newval) + return nil, true, nil + }, + }, + res: []string{}, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + var oiba = orderIDsByAccount{} + if oids, err := oiba.unsafeGetOrderIDsByAccount(tc.db, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.res, oids) + } + } + }) + } +}