Merge branch 'smallstep_master' into extractable
commit
aa80bf9f07
@ -0,0 +1,9 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: Ask on Discord
|
||||
url: https://discord.gg/7xgjhVAg6g
|
||||
about: You can ask for help here!
|
||||
- name: Want to contribute to step certificates?
|
||||
url: https://github.com/smallstep/certificates/blob/master/docs/CONTRIBUTING.md
|
||||
about: Be sure to read contributing guidelines!
|
||||
|
@ -1,14 +1,12 @@
|
||||
name: labeler
|
||||
name: Pull Request Labeler
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
pull_request_target
|
||||
|
||||
jobs:
|
||||
label:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v3
|
||||
- uses: actions/labeler@v3.0.2
|
||||
with:
|
||||
repo-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
configuration-path: .github/needs-triage-labeler.yml
|
||||
|
||||
|
@ -0,0 +1,160 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
// CreateAdminRequest represents the body for a CreateAdmin request.
|
||||
type CreateAdminRequest struct {
|
||||
Subject string `json:"subject"`
|
||||
Provisioner string `json:"provisioner"`
|
||||
Type linkedca.Admin_Type `json:"type"`
|
||||
}
|
||||
|
||||
// Validate validates a new-admin request body.
|
||||
func (car *CreateAdminRequest) Validate() error {
|
||||
if car.Subject == "" {
|
||||
return admin.NewError(admin.ErrorBadRequestType, "subject cannot be empty")
|
||||
}
|
||||
if car.Provisioner == "" {
|
||||
return admin.NewError(admin.ErrorBadRequestType, "provisioner cannot be empty")
|
||||
}
|
||||
switch car.Type {
|
||||
case linkedca.Admin_SUPER_ADMIN, linkedca.Admin_ADMIN:
|
||||
default:
|
||||
return admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAdminsResponse for returning a list of admins.
|
||||
type GetAdminsResponse struct {
|
||||
Admins []*linkedca.Admin `json:"admins"`
|
||||
NextCursor string `json:"nextCursor"`
|
||||
}
|
||||
|
||||
// UpdateAdminRequest represents the body for a UpdateAdmin request.
|
||||
type UpdateAdminRequest struct {
|
||||
Type linkedca.Admin_Type `json:"type"`
|
||||
}
|
||||
|
||||
// Validate validates a new-admin request body.
|
||||
func (uar *UpdateAdminRequest) Validate() error {
|
||||
switch uar.Type {
|
||||
case linkedca.Admin_SUPER_ADMIN, linkedca.Admin_ADMIN:
|
||||
default:
|
||||
return admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteResponse is the resource for successful DELETE responses.
|
||||
type DeleteResponse struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// GetAdmin returns the requested admin, or an error.
|
||||
func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
adm, ok := h.auth.LoadAdminByID(id)
|
||||
if !ok {
|
||||
api.WriteError(w, admin.NewError(admin.ErrorNotFoundType,
|
||||
"admin %s not found", id))
|
||||
return
|
||||
}
|
||||
api.ProtoJSON(w, adm)
|
||||
}
|
||||
|
||||
// GetAdmins returns a segment of admins associated with the authority.
|
||||
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := api.ParseCursor(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
"error parsing cursor and limit from query params"))
|
||||
return
|
||||
}
|
||||
|
||||
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
|
||||
return
|
||||
}
|
||||
api.JSON(w, &GetAdminsResponse{
|
||||
Admins: admins,
|
||||
NextCursor: nextCursor,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateAdmin creates a new admin.
|
||||
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body CreateAdminRequest
|
||||
if err := api.ReadJSON(r.Body, &body); err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
p, err := h.auth.LoadProvisionerByName(body.Provisioner)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
|
||||
return
|
||||
}
|
||||
adm := &linkedca.Admin{
|
||||
ProvisionerId: p.GetID(),
|
||||
Subject: body.Subject,
|
||||
Type: body.Type,
|
||||
}
|
||||
// Store to authority collection.
|
||||
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error storing admin"))
|
||||
return
|
||||
}
|
||||
|
||||
api.ProtoJSONStatus(w, adm, http.StatusCreated)
|
||||
}
|
||||
|
||||
// DeleteAdmin deletes admin.
|
||||
func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
|
||||
return
|
||||
}
|
||||
|
||||
api.JSON(w, &DeleteResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// UpdateAdmin updates an existing admin.
|
||||
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body UpdateAdminRequest
|
||||
if err := api.ReadJSON(r.Body, &body); err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error updating admin %s", id))
|
||||
return
|
||||
}
|
||||
|
||||
api.ProtoJSON(w, adm)
|
||||
}
|
@ -0,0 +1,41 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
)
|
||||
|
||||
// Handler is the ACME API request handler.
|
||||
type Handler struct {
|
||||
db admin.DB
|
||||
auth *authority.Authority
|
||||
}
|
||||
|
||||
// NewHandler returns a new Authority Config Handler.
|
||||
func NewHandler(auth *authority.Authority) api.RouterHandler {
|
||||
h := &Handler{db: auth.GetAdminDatabase(), auth: auth}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
authnz := func(next nextHTTP) nextHTTP {
|
||||
return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next))
|
||||
}
|
||||
|
||||
// Provisioners
|
||||
r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner))
|
||||
r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners))
|
||||
r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner))
|
||||
r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner))
|
||||
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner))
|
||||
|
||||
// Admins
|
||||
r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin))
|
||||
r.MethodFunc("GET", "/admins", authnz(h.GetAdmins))
|
||||
r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin))
|
||||
r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin))
|
||||
r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin))
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
)
|
||||
|
||||
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||
|
||||
// requireAPIEnabled is a middleware that ensures the Administration API
|
||||
// is enabled before servicing requests.
|
||||
func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.auth.IsAdminAPIEnabled() {
|
||||
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType,
|
||||
"administration API not enabled"))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token.
|
||||
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
tok := r.Header.Get("Authorization")
|
||||
if tok == "" {
|
||||
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType,
|
||||
"missing authorization header token"))
|
||||
return
|
||||
}
|
||||
|
||||
adm, err := h.auth.AuthorizeAdminToken(r, tok)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), adminContextKey, adm)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// ContextKey is the key type for storing and searching for ACME request
|
||||
// essentials in the context of a request.
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
// adminContextKey account key
|
||||
adminContextKey = ContextKey("admin")
|
||||
)
|
@ -0,0 +1,175 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
// GetProvisionersResponse is the type for GET /admin/provisioners responses.
|
||||
type GetProvisionersResponse struct {
|
||||
Provisioners provisioner.List `json:"provisioners"`
|
||||
NextCursor string `json:"nextCursor"`
|
||||
}
|
||||
|
||||
// GetProvisioner returns the requested provisioner, or an error.
|
||||
func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
if len(id) > 0 {
|
||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
prov, err := h.db.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
api.ProtoJSON(w, prov)
|
||||
}
|
||||
|
||||
// GetProvisioners returns the given segment of provisioners associated with the authority.
|
||||
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := api.ParseCursor(r)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
"error parsing cursor & limit query params"))
|
||||
return
|
||||
}
|
||||
|
||||
p, next, err := h.auth.GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
api.WriteError(w, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
api.JSON(w, &GetProvisionersResponse{
|
||||
Provisioners: p,
|
||||
NextCursor: next,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateProvisioner creates a new prov.
|
||||
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var prov = new(linkedca.Provisioner)
|
||||
if err := api.ReadProtoJSON(r.Body, prov); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Validate inputs
|
||||
if err := authority.ValidateClaims(prov.Claims); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
|
||||
return
|
||||
}
|
||||
api.ProtoJSONStatus(w, prov, http.StatusCreated)
|
||||
}
|
||||
|
||||
// DeleteProvisioner deletes a provisioner.
|
||||
func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.URL.Query().Get("id")
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
if len(id) > 0 {
|
||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
|
||||
return
|
||||
}
|
||||
|
||||
api.JSON(w, &DeleteResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// UpdateProvisioner updates an existing prov.
|
||||
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var nu = new(linkedca.Provisioner)
|
||||
if err := api.ReadProtoJSON(r.Body, nu); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
name := chi.URLParam(r, "name")
|
||||
_old, err := h.auth.LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
|
||||
return
|
||||
}
|
||||
|
||||
old, err := h.db.GetProvisioner(r.Context(), _old.GetID())
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
|
||||
return
|
||||
}
|
||||
|
||||
if nu.Id != old.Id {
|
||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner ID"))
|
||||
return
|
||||
}
|
||||
if nu.Type != old.Type {
|
||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner type"))
|
||||
return
|
||||
}
|
||||
if nu.AuthorityId != old.AuthorityId {
|
||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner authorityID"))
|
||||
return
|
||||
}
|
||||
if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) {
|
||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner createdAt"))
|
||||
return
|
||||
}
|
||||
if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) {
|
||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner deletedAt"))
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Validate inputs
|
||||
if err := authority.ValidateClaims(nu.Claims); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
api.ProtoJSON(w, nu)
|
||||
}
|
@ -0,0 +1,179 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultAuthorityID is the default AuthorityID. This will be the ID
|
||||
// of the first Authority created, as well as the default AuthorityID
|
||||
// if one is not specified in the configuration.
|
||||
DefaultAuthorityID = "00000000-0000-0000-0000-000000000000"
|
||||
)
|
||||
|
||||
// ErrNotFound is an error that should be used by the authority.DB interface to
|
||||
// indicate that an entity does not exist.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// UnmarshalProvisionerDetails unmarshals details type to the specific provisioner details.
|
||||
func UnmarshalProvisionerDetails(typ linkedca.Provisioner_Type, data []byte) (*linkedca.ProvisionerDetails, error) {
|
||||
var v linkedca.ProvisionerDetails
|
||||
switch typ {
|
||||
case linkedca.Provisioner_JWK:
|
||||
v.Data = new(linkedca.ProvisionerDetails_JWK)
|
||||
case linkedca.Provisioner_OIDC:
|
||||
v.Data = new(linkedca.ProvisionerDetails_OIDC)
|
||||
case linkedca.Provisioner_GCP:
|
||||
v.Data = new(linkedca.ProvisionerDetails_GCP)
|
||||
case linkedca.Provisioner_AWS:
|
||||
v.Data = new(linkedca.ProvisionerDetails_AWS)
|
||||
case linkedca.Provisioner_AZURE:
|
||||
v.Data = new(linkedca.ProvisionerDetails_Azure)
|
||||
case linkedca.Provisioner_ACME:
|
||||
v.Data = new(linkedca.ProvisionerDetails_ACME)
|
||||
case linkedca.Provisioner_X5C:
|
||||
v.Data = new(linkedca.ProvisionerDetails_X5C)
|
||||
case linkedca.Provisioner_K8SSA:
|
||||
v.Data = new(linkedca.ProvisionerDetails_K8SSA)
|
||||
case linkedca.Provisioner_SSHPOP:
|
||||
v.Data = new(linkedca.ProvisionerDetails_SSHPOP)
|
||||
case linkedca.Provisioner_SCEP:
|
||||
v.Data = new(linkedca.ProvisionerDetails_SCEP)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provisioner type %s", typ)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, v.Data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &linkedca.ProvisionerDetails{Data: v.Data}, nil
|
||||
}
|
||||
|
||||
// DB is the DB interface expected by the step-ca Admin API.
|
||||
type DB interface {
|
||||
CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error
|
||||
GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error)
|
||||
GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error)
|
||||
UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error
|
||||
DeleteProvisioner(ctx context.Context, id string) error
|
||||
|
||||
CreateAdmin(ctx context.Context, admin *linkedca.Admin) error
|
||||
GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error)
|
||||
GetAdmins(ctx context.Context) ([]*linkedca.Admin, error)
|
||||
UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error
|
||||
DeleteAdmin(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
// MockDB is an implementation of the DB interface that should only be used as
|
||||
// a mock in tests.
|
||||
type MockDB struct {
|
||||
MockCreateProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error
|
||||
MockGetProvisioner func(ctx context.Context, id string) (*linkedca.Provisioner, error)
|
||||
MockGetProvisioners func(ctx context.Context) ([]*linkedca.Provisioner, error)
|
||||
MockUpdateProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error
|
||||
MockDeleteProvisioner func(ctx context.Context, id string) error
|
||||
|
||||
MockCreateAdmin func(ctx context.Context, adm *linkedca.Admin) error
|
||||
MockGetAdmin func(ctx context.Context, id string) (*linkedca.Admin, error)
|
||||
MockGetAdmins func(ctx context.Context) ([]*linkedca.Admin, error)
|
||||
MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error
|
||||
MockDeleteAdmin func(ctx context.Context, id string) error
|
||||
|
||||
MockError error
|
||||
MockRet1 interface{}
|
||||
}
|
||||
|
||||
// CreateProvisioner mock.
|
||||
func (m *MockDB) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
if m.MockCreateProvisioner != nil {
|
||||
return m.MockCreateProvisioner(ctx, prov)
|
||||
} else if m.MockError != nil {
|
||||
return m.MockError
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetProvisioner mock.
|
||||
func (m *MockDB) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
if m.MockGetProvisioner != nil {
|
||||
return m.MockGetProvisioner(ctx, id)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Provisioner), m.MockError
|
||||
}
|
||||
|
||||
// GetProvisioners mock
|
||||
func (m *MockDB) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
|
||||
if m.MockGetProvisioners != nil {
|
||||
return m.MockGetProvisioners(ctx)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.([]*linkedca.Provisioner), m.MockError
|
||||
}
|
||||
|
||||
// UpdateProvisioner mock
|
||||
func (m *MockDB) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
if m.MockUpdateProvisioner != nil {
|
||||
return m.MockUpdateProvisioner(ctx, prov)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// DeleteProvisioner mock
|
||||
func (m *MockDB) DeleteProvisioner(ctx context.Context, id string) error {
|
||||
if m.MockDeleteProvisioner != nil {
|
||||
return m.MockDeleteProvisioner(ctx, id)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// CreateAdmin mock
|
||||
func (m *MockDB) CreateAdmin(ctx context.Context, admin *linkedca.Admin) error {
|
||||
if m.MockCreateAdmin != nil {
|
||||
return m.MockCreateAdmin(ctx, admin)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// GetAdmin mock.
|
||||
func (m *MockDB) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) {
|
||||
if m.MockGetAdmin != nil {
|
||||
return m.MockGetAdmin(ctx, id)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Admin), m.MockError
|
||||
}
|
||||
|
||||
// GetAdmins mock
|
||||
func (m *MockDB) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
|
||||
if m.MockGetAdmins != nil {
|
||||
return m.MockGetAdmins(ctx)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return m.MockRet1.([]*linkedca.Admin), m.MockError
|
||||
}
|
||||
|
||||
// UpdateAdmin mock
|
||||
func (m *MockDB) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error {
|
||||
if m.MockUpdateAdmin != nil {
|
||||
return m.MockUpdateAdmin(ctx, adm)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
// DeleteAdmin mock
|
||||
func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error {
|
||||
if m.MockDeleteAdmin != nil {
|
||||
return m.MockDeleteAdmin(ctx, id)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
@ -0,0 +1,178 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/nosql"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// dbAdmin is the database representation of the Admin type.
|
||||
type dbAdmin struct {
|
||||
ID string `json:"id"`
|
||||
AuthorityID string `json:"authorityID"`
|
||||
ProvisionerID string `json:"provisionerID"`
|
||||
Subject string `json:"subject"`
|
||||
Type linkedca.Admin_Type `json:"type"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
DeletedAt time.Time `json:"deletedAt"`
|
||||
}
|
||||
|
||||
func (dba *dbAdmin) convert() *linkedca.Admin {
|
||||
return &linkedca.Admin{
|
||||
Id: dba.ID,
|
||||
AuthorityId: dba.AuthorityID,
|
||||
ProvisionerId: dba.ProvisionerID,
|
||||
Subject: dba.Subject,
|
||||
Type: dba.Type,
|
||||
CreatedAt: timestamppb.New(dba.CreatedAt),
|
||||
DeletedAt: timestamppb.New(dba.DeletedAt),
|
||||
}
|
||||
}
|
||||
|
||||
func (dba *dbAdmin) clone() *dbAdmin {
|
||||
u := *dba
|
||||
return &u
|
||||
}
|
||||
|
||||
func (db *DB) getDBAdminBytes(ctx context.Context, id string) ([]byte, error) {
|
||||
data, err := db.db.Get(adminsTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading admin %s", id)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (db *DB) unmarshalDBAdmin(data []byte, id string) (*dbAdmin, error) {
|
||||
var dba = new(dbAdmin)
|
||||
if err := json.Unmarshal(data, dba); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling admin %s into dbAdmin", id)
|
||||
}
|
||||
if !dba.DeletedAt.IsZero() {
|
||||
return nil, admin.NewError(admin.ErrorDeletedType, "admin %s is deleted", id)
|
||||
}
|
||||
if dba.AuthorityID != db.authorityID {
|
||||
return nil, admin.NewError(admin.ErrorAuthorityMismatchType,
|
||||
"admin %s is not owned by authority %s", dba.ID, db.authorityID)
|
||||
}
|
||||
return dba, nil
|
||||
}
|
||||
|
||||
func (db *DB) getDBAdmin(ctx context.Context, id string) (*dbAdmin, error) {
|
||||
data, err := db.getDBAdminBytes(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dba, err := db.unmarshalDBAdmin(data, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dba, nil
|
||||
}
|
||||
|
||||
func (db *DB) unmarshalAdmin(data []byte, id string) (*linkedca.Admin, error) {
|
||||
dba, err := db.unmarshalDBAdmin(data, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dba.convert(), nil
|
||||
}
|
||||
|
||||
// GetAdmin retrieves and unmarshals a admin from the database.
|
||||
func (db *DB) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) {
|
||||
data, err := db.getDBAdminBytes(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
adm, err := db.unmarshalAdmin(data, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return adm, nil
|
||||
}
|
||||
|
||||
// GetAdmins retrieves and unmarshals all active (not deleted) admins
|
||||
// from the database.
|
||||
// TODO should we be paginating?
|
||||
func (db *DB) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
|
||||
dbEntries, err := db.db.List(adminsTable)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error loading admins")
|
||||
}
|
||||
var admins = []*linkedca.Admin{}
|
||||
for _, entry := range dbEntries {
|
||||
adm, err := db.unmarshalAdmin(entry.Value, string(entry.Key))
|
||||
if err != nil {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if k.IsType(admin.ErrorDeletedType) || k.IsType(admin.ErrorAuthorityMismatchType) {
|
||||
continue
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if adm.AuthorityId != db.authorityID {
|
||||
continue
|
||||
}
|
||||
admins = append(admins, adm)
|
||||
}
|
||||
return admins, nil
|
||||
}
|
||||
|
||||
// CreateAdmin stores a new admin to the database.
|
||||
func (db *DB) CreateAdmin(ctx context.Context, adm *linkedca.Admin) error {
|
||||
var err error
|
||||
adm.Id, err = randID()
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error generating random id for admin")
|
||||
}
|
||||
adm.AuthorityId = db.authorityID
|
||||
|
||||
dba := &dbAdmin{
|
||||
ID: adm.Id,
|
||||
AuthorityID: db.authorityID,
|
||||
ProvisionerID: adm.ProvisionerId,
|
||||
Subject: adm.Subject,
|
||||
Type: adm.Type,
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
|
||||
return db.save(ctx, dba.ID, dba, nil, "admin", adminsTable)
|
||||
}
|
||||
|
||||
// UpdateAdmin saves an updated admin to the database.
|
||||
func (db *DB) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error {
|
||||
old, err := db.getDBAdmin(ctx, adm.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
nu.Type = adm.Type
|
||||
|
||||
return db.save(ctx, old.ID, nu, old, "admin", adminsTable)
|
||||
}
|
||||
|
||||
// DeleteAdmin saves an updated admin to the database.
|
||||
func (db *DB) DeleteAdmin(ctx context.Context, id string) error {
|
||||
old, err := db.getDBAdmin(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
nu.DeletedAt = clock.Now()
|
||||
|
||||
return db.save(ctx, old.ID, nu, old, "admin", adminsTable)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,88 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
nosqlDB "github.com/smallstep/nosql/database"
|
||||
"go.step.sm/crypto/randutil"
|
||||
)
|
||||
|
||||
var (
|
||||
adminsTable = []byte("admins")
|
||||
provisionersTable = []byte("provisioners")
|
||||
)
|
||||
|
||||
// DB is a struct that implements the AdminDB interface.
|
||||
type DB struct {
|
||||
db nosqlDB.DB
|
||||
authorityID string
|
||||
}
|
||||
|
||||
// New configures and returns a new Authority DB backend implemented using a nosql DB.
|
||||
func New(db nosqlDB.DB, authorityID string) (*DB, error) {
|
||||
tables := [][]byte{adminsTable, provisionersTable}
|
||||
for _, b := range tables {
|
||||
if err := db.CreateTable(b); err != nil {
|
||||
return nil, errors.Wrapf(err, "error creating table %s",
|
||||
string(b))
|
||||
}
|
||||
}
|
||||
return &DB{db, authorityID}, nil
|
||||
}
|
||||
|
||||
// save writes the new data to the database, overwriting the old data if it
|
||||
// existed.
|
||||
func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
|
||||
var (
|
||||
err error
|
||||
newB []byte
|
||||
)
|
||||
if nu == nil {
|
||||
newB = nil
|
||||
} else {
|
||||
newB, err = json.Marshal(nu)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error marshaling authority type: %s, value: %v", typ, nu)
|
||||
}
|
||||
}
|
||||
var oldB []byte
|
||||
if old == nil {
|
||||
oldB = nil
|
||||
} else {
|
||||
oldB, err = json.Marshal(old)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error marshaling admin type: %s, value: %v", typ, old)
|
||||
}
|
||||
}
|
||||
|
||||
_, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB)
|
||||
switch {
|
||||
case err != nil:
|
||||
return errors.Wrapf(err, "error saving authority %s", typ)
|
||||
case !swapped:
|
||||
return errors.Errorf("error saving authority %s; changed since last read", typ)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func randID() (val string, err error) {
|
||||
val, err = randutil.UUIDv4()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error generating random alphanumeric ID")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// Clock that returns time in UTC rounded to seconds.
|
||||
type Clock struct{}
|
||||
|
||||
// Now returns the UTC time rounded to seconds.
|
||||
func (c *Clock) Now() time.Time {
|
||||
return time.Now().UTC().Truncate(time.Second)
|
||||
}
|
||||
|
||||
var clock = new(Clock)
|
@ -0,0 +1,211 @@
|
||||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/nosql"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// dbProvisioner is the database representation of a Provisioner type.
|
||||
type dbProvisioner struct {
|
||||
ID string `json:"id"`
|
||||
AuthorityID string `json:"authorityID"`
|
||||
Type linkedca.Provisioner_Type `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Claims *linkedca.Claims `json:"claims"`
|
||||
Details []byte `json:"details"`
|
||||
X509Template *linkedca.Template `json:"x509Template"`
|
||||
SSHTemplate *linkedca.Template `json:"sshTemplate"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
DeletedAt time.Time `json:"deletedAt"`
|
||||
}
|
||||
|
||||
func (dbp *dbProvisioner) clone() *dbProvisioner {
|
||||
u := *dbp
|
||||
return &u
|
||||
}
|
||||
|
||||
func (dbp *dbProvisioner) convert2linkedca() (*linkedca.Provisioner, error) {
|
||||
details, err := admin.UnmarshalProvisionerDetails(dbp.Type, dbp.Details)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &linkedca.Provisioner{
|
||||
Id: dbp.ID,
|
||||
AuthorityId: dbp.AuthorityID,
|
||||
Type: dbp.Type,
|
||||
Name: dbp.Name,
|
||||
Claims: dbp.Claims,
|
||||
Details: details,
|
||||
X509Template: dbp.X509Template,
|
||||
SshTemplate: dbp.SSHTemplate,
|
||||
CreatedAt: timestamppb.New(dbp.CreatedAt),
|
||||
DeletedAt: timestamppb.New(dbp.DeletedAt),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (db *DB) getDBProvisionerBytes(ctx context.Context, id string) ([]byte, error) {
|
||||
data, err := db.db.Get(provisionersTable, []byte(id))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", id)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading provisioner %s", id)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (db *DB) unmarshalDBProvisioner(data []byte, id string) (*dbProvisioner, error) {
|
||||
var dbp = new(dbProvisioner)
|
||||
if err := json.Unmarshal(data, dbp); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling provisioner %s into dbProvisioner", id)
|
||||
}
|
||||
if !dbp.DeletedAt.IsZero() {
|
||||
return nil, admin.NewError(admin.ErrorDeletedType, "provisioner %s is deleted", id)
|
||||
}
|
||||
if dbp.AuthorityID != db.authorityID {
|
||||
return nil, admin.NewError(admin.ErrorAuthorityMismatchType,
|
||||
"provisioner %s is not owned by authority %s", id, db.authorityID)
|
||||
}
|
||||
return dbp, nil
|
||||
}
|
||||
|
||||
func (db *DB) getDBProvisioner(ctx context.Context, id string) (*dbProvisioner, error) {
|
||||
data, err := db.getDBProvisionerBytes(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbp, err := db.unmarshalDBProvisioner(data, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dbp, nil
|
||||
}
|
||||
|
||||
func (db *DB) unmarshalProvisioner(data []byte, id string) (*linkedca.Provisioner, error) {
|
||||
dbp, err := db.unmarshalDBProvisioner(data, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dbp.convert2linkedca()
|
||||
}
|
||||
|
||||
// GetProvisioner retrieves and unmarshals a provisioner from the database.
|
||||
func (db *DB) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
data, err := db.getDBProvisionerBytes(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prov, err := db.unmarshalProvisioner(data, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return prov, nil
|
||||
}
|
||||
|
||||
// GetProvisioners retrieves and unmarshals all active (not deleted) provisioners
|
||||
// from the database.
|
||||
func (db *DB) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
|
||||
dbEntries, err := db.db.List(provisionersTable)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error loading provisioners")
|
||||
}
|
||||
var provs []*linkedca.Provisioner
|
||||
for _, entry := range dbEntries {
|
||||
prov, err := db.unmarshalProvisioner(entry.Value, string(entry.Key))
|
||||
if err != nil {
|
||||
switch k := err.(type) {
|
||||
case *admin.Error:
|
||||
if k.IsType(admin.ErrorDeletedType) || k.IsType(admin.ErrorAuthorityMismatchType) {
|
||||
continue
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if prov.AuthorityId != db.authorityID {
|
||||
continue
|
||||
}
|
||||
provs = append(provs, prov)
|
||||
}
|
||||
return provs, nil
|
||||
}
|
||||
|
||||
// CreateProvisioner stores a new provisioner to the database.
|
||||
func (db *DB) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
var err error
|
||||
prov.Id, err = randID()
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error generating random id for provisioner")
|
||||
}
|
||||
|
||||
details, err := json.Marshal(prov.Details.GetData())
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error marshaling details when creating provisioner %s", prov.Name)
|
||||
}
|
||||
|
||||
dbp := &dbProvisioner{
|
||||
ID: prov.Id,
|
||||
AuthorityID: db.authorityID,
|
||||
Type: prov.Type,
|
||||
Name: prov.Name,
|
||||
Claims: prov.Claims,
|
||||
Details: details,
|
||||
X509Template: prov.X509Template,
|
||||
SSHTemplate: prov.SshTemplate,
|
||||
CreatedAt: clock.Now(),
|
||||
}
|
||||
|
||||
if err := db.save(ctx, prov.Id, dbp, nil, "provisioner", provisionersTable); err != nil {
|
||||
return admin.WrapErrorISE(err, "error creating provisioner %s", prov.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateProvisioner saves an updated provisioner to the database.
|
||||
func (db *DB) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
old, err := db.getDBProvisioner(ctx, prov.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
|
||||
if old.Type != prov.Type {
|
||||
return admin.NewError(admin.ErrorBadRequestType, "cannot update provisioner type")
|
||||
}
|
||||
nu.Name = prov.Name
|
||||
nu.Claims = prov.Claims
|
||||
nu.Details, err = json.Marshal(prov.Details.GetData())
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error marshaling details when updating provisioner %s", prov.Name)
|
||||
}
|
||||
nu.X509Template = prov.X509Template
|
||||
nu.SSHTemplate = prov.SshTemplate
|
||||
|
||||
return db.save(ctx, prov.Id, nu, old, "provisioner", provisionersTable)
|
||||
}
|
||||
|
||||
// DeleteProvisioner saves an updated admin to the database.
|
||||
func (db *DB) DeleteProvisioner(ctx context.Context, id string) error {
|
||||
old, err := db.getDBProvisioner(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nu := old.clone()
|
||||
nu.DeletedAt = clock.Now()
|
||||
|
||||
return db.save(ctx, old.ID, nu, old, "provisioner", provisionersTable)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,223 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
// ProblemType is the type of the Admin problem.
|
||||
type ProblemType int
|
||||
|
||||
const (
|
||||
// ErrorNotFoundType resource not found.
|
||||
ErrorNotFoundType ProblemType = iota
|
||||
// ErrorAuthorityMismatchType resource Authority ID does not match the
|
||||
// context Authority ID.
|
||||
ErrorAuthorityMismatchType
|
||||
// ErrorDeletedType resource has been deleted.
|
||||
ErrorDeletedType
|
||||
// ErrorBadRequestType bad request.
|
||||
ErrorBadRequestType
|
||||
// ErrorNotImplementedType not implemented.
|
||||
ErrorNotImplementedType
|
||||
// ErrorUnauthorizedType internal server error.
|
||||
ErrorUnauthorizedType
|
||||
// ErrorServerInternalType internal server error.
|
||||
ErrorServerInternalType
|
||||
)
|
||||
|
||||
// String returns the string representation of the admin problem type,
|
||||
// fulfilling the Stringer interface.
|
||||
func (ap ProblemType) String() string {
|
||||
switch ap {
|
||||
case ErrorNotFoundType:
|
||||
return "notFound"
|
||||
case ErrorAuthorityMismatchType:
|
||||
return "authorityMismatch"
|
||||
case ErrorDeletedType:
|
||||
return "deleted"
|
||||
case ErrorBadRequestType:
|
||||
return "badRequest"
|
||||
case ErrorNotImplementedType:
|
||||
return "notImplemented"
|
||||
case ErrorUnauthorizedType:
|
||||
return "unauthorized"
|
||||
case ErrorServerInternalType:
|
||||
return "internalServerError"
|
||||
default:
|
||||
return fmt.Sprintf("unsupported error type '%d'", int(ap))
|
||||
}
|
||||
}
|
||||
|
||||
type errorMetadata struct {
|
||||
details string
|
||||
status int
|
||||
typ string
|
||||
String string
|
||||
}
|
||||
|
||||
var (
|
||||
errorServerInternalMetadata = errorMetadata{
|
||||
typ: ErrorServerInternalType.String(),
|
||||
details: "the server experienced an internal error",
|
||||
status: 500,
|
||||
}
|
||||
errorMap = map[ProblemType]errorMetadata{
|
||||
ErrorNotFoundType: {
|
||||
typ: ErrorNotFoundType.String(),
|
||||
details: "resource not found",
|
||||
status: http.StatusNotFound,
|
||||
},
|
||||
ErrorAuthorityMismatchType: {
|
||||
typ: ErrorAuthorityMismatchType.String(),
|
||||
details: "resource not owned by authority",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
ErrorDeletedType: {
|
||||
typ: ErrorDeletedType.String(),
|
||||
details: "resource is deleted",
|
||||
status: http.StatusNotFound,
|
||||
},
|
||||
ErrorNotImplementedType: {
|
||||
typ: ErrorNotImplementedType.String(),
|
||||
details: "not implemented",
|
||||
status: http.StatusNotImplemented,
|
||||
},
|
||||
ErrorBadRequestType: {
|
||||
typ: ErrorBadRequestType.String(),
|
||||
details: "bad request",
|
||||
status: http.StatusBadRequest,
|
||||
},
|
||||
ErrorUnauthorizedType: {
|
||||
typ: ErrorUnauthorizedType.String(),
|
||||
details: "unauthorized",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
ErrorServerInternalType: errorServerInternalMetadata,
|
||||
}
|
||||
)
|
||||
|
||||
// Error represents an Admin
|
||||
type Error struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
Message string `json:"message"`
|
||||
Err error `json:"-"`
|
||||
Status int `json:"-"`
|
||||
}
|
||||
|
||||
// IsType returns true if the error type matches the input type.
|
||||
func (e *Error) IsType(pt ProblemType) bool {
|
||||
return pt.String() == e.Type
|
||||
}
|
||||
|
||||
// NewError creates a new Error type.
|
||||
func NewError(pt ProblemType, msg string, args ...interface{}) *Error {
|
||||
return newError(pt, errors.Errorf(msg, args...))
|
||||
}
|
||||
|
||||
func newError(pt ProblemType, err error) *Error {
|
||||
meta, ok := errorMap[pt]
|
||||
if !ok {
|
||||
meta = errorServerInternalMetadata
|
||||
return &Error{
|
||||
Type: meta.typ,
|
||||
Detail: meta.details,
|
||||
Status: meta.status,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
return &Error{
|
||||
Type: meta.typ,
|
||||
Detail: meta.details,
|
||||
Status: meta.status,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorISE creates a new ErrorServerInternalType Error.
|
||||
func NewErrorISE(msg string, args ...interface{}) *Error {
|
||||
return NewError(ErrorServerInternalType, msg, args...)
|
||||
}
|
||||
|
||||
// WrapError attempts to wrap the internal error.
|
||||
func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Error {
|
||||
switch e := err.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case *Error:
|
||||
if e.Err == nil {
|
||||
e.Err = errors.Errorf(msg+"; "+e.Detail, args...)
|
||||
} else {
|
||||
e.Err = errors.Wrapf(e.Err, msg, args...)
|
||||
}
|
||||
return e
|
||||
default:
|
||||
return newError(typ, errors.Wrapf(err, msg, args...))
|
||||
}
|
||||
}
|
||||
|
||||
// WrapErrorISE shortcut to wrap an internal server error type.
|
||||
func WrapErrorISE(err error, msg string, args ...interface{}) *Error {
|
||||
return WrapError(ErrorServerInternalType, err, msg, args...)
|
||||
}
|
||||
|
||||
// StatusCode returns the status code and implements the StatusCoder interface.
|
||||
func (e *Error) StatusCode() int {
|
||||
return e.Status
|
||||
}
|
||||
|
||||
// Error allows AError to implement the error interface.
|
||||
func (e *Error) Error() string {
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
// Cause returns the internal error and implements the Causer interface.
|
||||
func (e *Error) Cause() error {
|
||||
if e.Err == nil {
|
||||
return errors.New(e.Detail)
|
||||
}
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// ToLog implements the EnableLogger interface.
|
||||
func (e *Error) ToLog() (interface{}, error) {
|
||||
b, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
return nil, WrapErrorISE(err, "error marshaling authority.Error for logging")
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// WriteError writes to w a JSON representation of the given error.
|
||||
func WriteError(w http.ResponseWriter, err *Error) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(err.StatusCode())
|
||||
|
||||
err.Message = err.Err.Error()
|
||||
// Write errors in the response writer
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"error": err.Err,
|
||||
})
|
||||
if os.Getenv("STEPDEBUG") == "1" {
|
||||
if e, ok := err.Err.(errs.StackTracer); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"stack-trace": fmt.Sprintf("%+v", e),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(err); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
@ -0,0 +1,243 @@
|
||||
package administrator
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
// DefaultAdminLimit is the default limit for listing provisioners.
|
||||
const DefaultAdminLimit = 20
|
||||
|
||||
// DefaultAdminMax is the maximum limit for listing provisioners.
|
||||
const DefaultAdminMax = 100
|
||||
|
||||
type adminSlice []*linkedca.Admin
|
||||
|
||||
func (p adminSlice) Len() int { return len(p) }
|
||||
func (p adminSlice) Less(i, j int) bool { return p[i].Id < p[j].Id }
|
||||
func (p adminSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
|
||||
// Collection is a memory map of admins.
|
||||
type Collection struct {
|
||||
byID *sync.Map
|
||||
bySubProv *sync.Map
|
||||
byProv *sync.Map
|
||||
sorted adminSlice
|
||||
provisioners *provisioner.Collection
|
||||
superCount int
|
||||
superCountByProvisioner map[string]int
|
||||
}
|
||||
|
||||
// NewCollection initializes a collection of provisioners. The given list of
|
||||
// audiences are the audiences used by the JWT provisioner.
|
||||
func NewCollection(provisioners *provisioner.Collection) *Collection {
|
||||
return &Collection{
|
||||
byID: new(sync.Map),
|
||||
byProv: new(sync.Map),
|
||||
bySubProv: new(sync.Map),
|
||||
superCountByProvisioner: map[string]int{},
|
||||
provisioners: provisioners,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadByID a admin by the ID.
|
||||
func (c *Collection) LoadByID(id string) (*linkedca.Admin, bool) {
|
||||
return loadAdmin(c.byID, id)
|
||||
}
|
||||
|
||||
type subProv struct {
|
||||
subject string
|
||||
provisioner string
|
||||
}
|
||||
|
||||
func newSubProv(subject, prov string) subProv {
|
||||
return subProv{subject, prov}
|
||||
}
|
||||
|
||||
// LoadBySubProv a admin by the subject and provisioner name.
|
||||
func (c *Collection) LoadBySubProv(sub, provName string) (*linkedca.Admin, bool) {
|
||||
return loadAdmin(c.bySubProv, newSubProv(sub, provName))
|
||||
}
|
||||
|
||||
// LoadByProvisioner a admin by the subject and provisioner name.
|
||||
func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool) {
|
||||
val, ok := c.byProv.Load(provName)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
admins, ok := val.([]*linkedca.Admin)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return admins, true
|
||||
}
|
||||
|
||||
// Store adds an admin to the collection and enforces the uniqueness of
|
||||
// admin IDs and amdin subject <-> provisioner name combos.
|
||||
func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error {
|
||||
// Input validation.
|
||||
if adm.ProvisionerId != prov.GetID() {
|
||||
return admin.NewErrorISE("admin.provisionerId does not match provisioner argument")
|
||||
}
|
||||
|
||||
// Store admin always in byID. ID must be unique.
|
||||
if _, loaded := c.byID.LoadOrStore(adm.Id, adm); loaded {
|
||||
return errors.New("cannot add multiple admins with the same id")
|
||||
}
|
||||
|
||||
provName := prov.GetName()
|
||||
// Store admin always in bySubProv. Subject <-> ProvisionerName must be unique.
|
||||
if _, loaded := c.bySubProv.LoadOrStore(newSubProv(adm.Subject, provName), adm); loaded {
|
||||
c.byID.Delete(adm.Id)
|
||||
return errors.New("cannot add multiple admins with the same subject and provisioner")
|
||||
}
|
||||
|
||||
var isSuper = (adm.Type == linkedca.Admin_SUPER_ADMIN)
|
||||
if admins, ok := c.LoadByProvisioner(provName); ok {
|
||||
c.byProv.Store(provName, append(admins, adm))
|
||||
if isSuper {
|
||||
c.superCountByProvisioner[provName]++
|
||||
}
|
||||
} else {
|
||||
c.byProv.Store(provName, []*linkedca.Admin{adm})
|
||||
if isSuper {
|
||||
c.superCountByProvisioner[provName] = 1
|
||||
}
|
||||
}
|
||||
if isSuper {
|
||||
c.superCount++
|
||||
}
|
||||
|
||||
c.sorted = append(c.sorted, adm)
|
||||
sort.Sort(c.sorted)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove deletes an admin from all associated collections and lists.
|
||||
func (c *Collection) Remove(id string) error {
|
||||
adm, ok := c.LoadByID(id)
|
||||
if !ok {
|
||||
return admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)
|
||||
}
|
||||
if adm.Type == linkedca.Admin_SUPER_ADMIN && c.SuperCount() == 1 {
|
||||
return admin.NewError(admin.ErrorBadRequestType, "cannot remove the last super admin")
|
||||
}
|
||||
prov, ok := c.provisioners.Load(adm.ProvisionerId)
|
||||
if !ok {
|
||||
return admin.NewError(admin.ErrorNotFoundType,
|
||||
"provisioner %s for admin %s not found", adm.ProvisionerId, id)
|
||||
}
|
||||
provName := prov.GetName()
|
||||
adminsByProv, ok := c.LoadByProvisioner(provName)
|
||||
if !ok {
|
||||
return admin.NewError(admin.ErrorNotFoundType,
|
||||
"admins not found for provisioner %s", provName)
|
||||
}
|
||||
|
||||
// Find index in sorted list.
|
||||
sortedIndex := sort.Search(c.sorted.Len(), func(i int) bool { return c.sorted[i].Id >= adm.Id })
|
||||
if c.sorted[sortedIndex].Id != adm.Id {
|
||||
return admin.NewError(admin.ErrorNotFoundType,
|
||||
"admin %s not found in sorted list", adm.Id)
|
||||
}
|
||||
|
||||
var found bool
|
||||
for i, a := range adminsByProv {
|
||||
if a.Id == adm.Id {
|
||||
// Remove admin from list. https://stackoverflow.com/questions/37334119/how-to-delete-an-element-from-a-slice-in-golang
|
||||
// Order does not matter.
|
||||
adminsByProv[i] = adminsByProv[len(adminsByProv)-1]
|
||||
c.byProv.Store(provName, adminsByProv[:len(adminsByProv)-1])
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return admin.NewError(admin.ErrorNotFoundType,
|
||||
"admin %s not found in adminsByProvisioner list", adm.Id)
|
||||
}
|
||||
|
||||
// Remove index in sorted list
|
||||
copy(c.sorted[sortedIndex:], c.sorted[sortedIndex+1:]) // Shift a[i+1:] left one index.
|
||||
c.sorted[len(c.sorted)-1] = nil // Erase last element (write zero value).
|
||||
c.sorted = c.sorted[:len(c.sorted)-1] // Truncate slice.
|
||||
|
||||
c.byID.Delete(adm.Id)
|
||||
c.bySubProv.Delete(newSubProv(adm.Subject, provName))
|
||||
|
||||
if adm.Type == linkedca.Admin_SUPER_ADMIN {
|
||||
c.superCount--
|
||||
c.superCountByProvisioner[provName]--
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update updates the given admin in all related lists and collections.
|
||||
func (c *Collection) Update(id string, nu *linkedca.Admin) (*linkedca.Admin, error) {
|
||||
adm, ok := c.LoadByID(id)
|
||||
if !ok {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", adm.Id)
|
||||
}
|
||||
if adm.Type == nu.Type {
|
||||
return adm, nil
|
||||
}
|
||||
if adm.Type == linkedca.Admin_SUPER_ADMIN && c.SuperCount() == 1 {
|
||||
return nil, admin.NewError(admin.ErrorBadRequestType, "cannot change role of last super admin")
|
||||
}
|
||||
|
||||
adm.Type = nu.Type
|
||||
return adm, nil
|
||||
}
|
||||
|
||||
// SuperCount returns the total number of admins.
|
||||
func (c *Collection) SuperCount() int {
|
||||
return c.superCount
|
||||
}
|
||||
|
||||
// SuperCountByProvisioner returns the total number of admins.
|
||||
func (c *Collection) SuperCountByProvisioner(provName string) int {
|
||||
if cnt, ok := c.superCountByProvisioner[provName]; ok {
|
||||
return cnt
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Find implements pagination on a list of sorted admins.
|
||||
func (c *Collection) Find(cursor string, limit int) ([]*linkedca.Admin, string) {
|
||||
switch {
|
||||
case limit <= 0:
|
||||
limit = DefaultAdminLimit
|
||||
case limit > DefaultAdminMax:
|
||||
limit = DefaultAdminMax
|
||||
}
|
||||
|
||||
n := c.sorted.Len()
|
||||
i := sort.Search(n, func(i int) bool { return c.sorted[i].Id >= cursor })
|
||||
|
||||
slice := []*linkedca.Admin{}
|
||||
for ; i < n && len(slice) < limit; i++ {
|
||||
slice = append(slice, c.sorted[i])
|
||||
}
|
||||
|
||||
if i < n {
|
||||
return slice, c.sorted[i].Id
|
||||
}
|
||||
return slice, ""
|
||||
}
|
||||
|
||||
func loadAdmin(m *sync.Map, key interface{}) (*linkedca.Admin, bool) {
|
||||
val, ok := m.Load(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
adm, ok := val.(*linkedca.Admin)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return adm, true
|
||||
}
|
@ -0,0 +1,97 @@
|
||||
package authority
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
// LoadAdminByID returns an *linkedca.Admin with the given ID.
|
||||
func (a *Authority) LoadAdminByID(id string) (*linkedca.Admin, bool) {
|
||||
a.adminMutex.RLock()
|
||||
defer a.adminMutex.RUnlock()
|
||||
return a.admins.LoadByID(id)
|
||||
}
|
||||
|
||||
// LoadAdminBySubProv returns an *linkedca.Admin with the given ID.
|
||||
func (a *Authority) LoadAdminBySubProv(subject, prov string) (*linkedca.Admin, bool) {
|
||||
a.adminMutex.RLock()
|
||||
defer a.adminMutex.RUnlock()
|
||||
return a.admins.LoadBySubProv(subject, prov)
|
||||
}
|
||||
|
||||
// GetAdmins returns a map listing each provisioner and the JWK Key Set
|
||||
// with their public keys.
|
||||
func (a *Authority) GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) {
|
||||
a.adminMutex.RLock()
|
||||
defer a.adminMutex.RUnlock()
|
||||
admins, nextCursor := a.admins.Find(cursor, limit)
|
||||
return admins, nextCursor, nil
|
||||
}
|
||||
|
||||
// StoreAdmin stores an *linkedca.Admin to the authority.
|
||||
func (a *Authority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
|
||||
if adm.ProvisionerId != prov.GetID() {
|
||||
return admin.NewErrorISE("admin.provisionerId does not match provisioner argument")
|
||||
}
|
||||
|
||||
if _, ok := a.admins.LoadBySubProv(adm.Subject, prov.GetName()); ok {
|
||||
return admin.NewError(admin.ErrorBadRequestType,
|
||||
"admin with subject %s and provisioner %s already exists", adm.Subject, prov.GetName())
|
||||
}
|
||||
// Store to database -- this will set the ID.
|
||||
if err := a.adminDB.CreateAdmin(ctx, adm); err != nil {
|
||||
return admin.WrapErrorISE(err, "error creating admin")
|
||||
}
|
||||
if err := a.admins.Store(adm, prov); err != nil {
|
||||
if err := a.reloadAdminResources(ctx); err != nil {
|
||||
return admin.WrapErrorISE(err, "error reloading admin resources on failed admin store")
|
||||
}
|
||||
return admin.WrapErrorISE(err, "error storing admin in authority cache")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateAdmin stores an *linkedca.Admin to the authority.
|
||||
func (a *Authority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
adm, err := a.admins.Update(id, nu)
|
||||
if err != nil {
|
||||
return nil, admin.WrapErrorISE(err, "error updating cached admin %s", id)
|
||||
}
|
||||
if err := a.adminDB.UpdateAdmin(ctx, adm); err != nil {
|
||||
if err := a.reloadAdminResources(ctx); err != nil {
|
||||
return nil, admin.WrapErrorISE(err, "error reloading admin resources on failed admin update")
|
||||
}
|
||||
return nil, admin.WrapErrorISE(err, "error updating admin %s", id)
|
||||
}
|
||||
return adm, nil
|
||||
}
|
||||
|
||||
// RemoveAdmin removes an *linkedca.Admin from the authority.
|
||||
func (a *Authority) RemoveAdmin(ctx context.Context, id string) error {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
|
||||
return a.removeAdmin(ctx, id)
|
||||
}
|
||||
|
||||
// removeAdmin helper that assumes lock.
|
||||
func (a *Authority) removeAdmin(ctx context.Context, id string) error {
|
||||
if err := a.admins.Remove(id); err != nil {
|
||||
return admin.WrapErrorISE(err, "error removing admin %s from authority cache", id)
|
||||
}
|
||||
if err := a.adminDB.DeleteAdmin(ctx, id); err != nil {
|
||||
if err := a.reloadAdminResources(ctx); err != nil {
|
||||
return admin.WrapErrorISE(err, "error reloading admin resources on failed admin remove")
|
||||
}
|
||||
return admin.WrapErrorISE(err, "error deleting admin %s", id)
|
||||
}
|
||||
return nil
|
||||
}
|
@ -1,298 +1,46 @@
|
||||
package authority
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
import "github.com/smallstep/certificates/authority/config"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
cas "github.com/smallstep/certificates/cas/apiv1"
|
||||
"github.com/smallstep/certificates/db"
|
||||
kms "github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
)
|
||||
// Config is an alias to support older APIs.
|
||||
type Config = config.Config
|
||||
|
||||
var (
|
||||
// DefaultTLSOptions represents the default TLS version as well as the cipher
|
||||
// suites used in the TLS certificates.
|
||||
DefaultTLSOptions = TLSOptions{
|
||||
CipherSuites: CipherSuites{
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
|
||||
},
|
||||
MinVersion: 1.2,
|
||||
MaxVersion: 1.2,
|
||||
Renegotiation: false,
|
||||
}
|
||||
defaultBackdate = time.Minute
|
||||
defaultDisableRenewal = false
|
||||
defaultEnableSSHCA = false
|
||||
globalProvisionerClaims = provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
|
||||
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
EnableSSHCA: &defaultEnableSSHCA,
|
||||
}
|
||||
)
|
||||
// LoadConfiguration is an alias to support older APIs.
|
||||
var LoadConfiguration = config.LoadConfiguration
|
||||
|
||||
// Config represents the CA configuration and it's mapped to a JSON object.
|
||||
type Config struct {
|
||||
Root multiString `json:"root"`
|
||||
FederatedRoots []string `json:"federatedRoots"`
|
||||
IntermediateCert string `json:"crt"`
|
||||
IntermediateKey string `json:"key"`
|
||||
Address string `json:"address"`
|
||||
InsecureAddress string `json:"insecureAddress"`
|
||||
DNSNames []string `json:"dnsNames"`
|
||||
KMS *kms.Options `json:"kms,omitempty"`
|
||||
SSH *SSHConfig `json:"ssh,omitempty"`
|
||||
Logger json.RawMessage `json:"logger,omitempty"`
|
||||
DB *db.Config `json:"db,omitempty"`
|
||||
Monitoring json.RawMessage `json:"monitoring,omitempty"`
|
||||
AuthorityConfig *AuthConfig `json:"authority,omitempty"`
|
||||
TLS *TLSOptions `json:"tls,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Templates *templates.Templates `json:"templates,omitempty"`
|
||||
}
|
||||
// AuthConfig is an alias to support older APIs.
|
||||
type AuthConfig = config.AuthConfig
|
||||
|
||||
// ASN1DN contains ASN1.DN attributes that are used in Subject and Issuer
|
||||
// x509 Certificate blocks.
|
||||
type ASN1DN struct {
|
||||
Country string `json:"country,omitempty" step:"country"`
|
||||
Organization string `json:"organization,omitempty" step:"organization"`
|
||||
OrganizationalUnit string `json:"organizationalUnit,omitempty" step:"organizationalUnit"`
|
||||
Locality string `json:"locality,omitempty" step:"locality"`
|
||||
Province string `json:"province,omitempty" step:"province"`
|
||||
StreetAddress string `json:"streetAddress,omitempty" step:"streetAddress"`
|
||||
CommonName string `json:"commonName,omitempty" step:"commonName"`
|
||||
}
|
||||
// TLS
|
||||
|
||||
// AuthConfig represents the configuration options for the authority. An
|
||||
// underlaying registration authority can also be configured using the
|
||||
// cas.Options.
|
||||
type AuthConfig struct {
|
||||
*cas.Options
|
||||
Provisioners provisioner.List `json:"provisioners"`
|
||||
Template *ASN1DN `json:"template,omitempty"`
|
||||
Claims *provisioner.Claims `json:"claims,omitempty"`
|
||||
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
|
||||
Backdate *provisioner.Duration `json:"backdate,omitempty"`
|
||||
}
|
||||
// ASN1DN is an alias to support older APIs.
|
||||
type ASN1DN = config.ASN1DN
|
||||
|
||||
// init initializes the required fields in the AuthConfig if they are not
|
||||
// provided.
|
||||
func (c *AuthConfig) init() {
|
||||
if c.Provisioners == nil {
|
||||
c.Provisioners = provisioner.List{}
|
||||
}
|
||||
if c.Template == nil {
|
||||
c.Template = &ASN1DN{}
|
||||
}
|
||||
if c.Backdate == nil {
|
||||
c.Backdate = &provisioner.Duration{
|
||||
Duration: defaultBackdate,
|
||||
}
|
||||
}
|
||||
}
|
||||
// DefaultTLSOptions is an alias to support older APIs.
|
||||
var DefaultTLSOptions = config.DefaultTLSOptions
|
||||
|
||||
// Validate validates the authority configuration.
|
||||
func (c *AuthConfig) Validate(audiences provisioner.Audiences) error {
|
||||
if c == nil {
|
||||
return errors.New("authority cannot be undefined")
|
||||
}
|
||||
// TLSOptions is an alias to support older APIs.
|
||||
type TLSOptions = config.TLSOptions
|
||||
|
||||
// Initialize required fields.
|
||||
c.init()
|
||||
// CipherSuites is an alias to support older APIs.
|
||||
type CipherSuites = config.CipherSuites
|
||||
|
||||
// Check that only one K8sSA is enabled
|
||||
var k8sCount int
|
||||
for _, p := range c.Provisioners {
|
||||
if p.GetType() == provisioner.TypeK8sSA {
|
||||
k8sCount++
|
||||
}
|
||||
}
|
||||
if k8sCount > 1 {
|
||||
return errors.New("cannot have more than one kubernetes service account provisioner")
|
||||
}
|
||||
// SSH
|
||||
|
||||
if c.Backdate.Duration < 0 {
|
||||
return errors.New("authority.backdate cannot be less than 0")
|
||||
}
|
||||
// SSHConfig is an alias to support older APIs.
|
||||
type SSHConfig = config.SSHConfig
|
||||
|
||||
return nil
|
||||
}
|
||||
// Bastion is an alias to support older APIs.
|
||||
type Bastion = config.Bastion
|
||||
|
||||
// LoadConfiguration parses the given filename in JSON format and returns the
|
||||
// configuration struct.
|
||||
func LoadConfiguration(filename string) (*Config, error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error opening %s", filename)
|
||||
}
|
||||
defer f.Close()
|
||||
// HostTag is an alias to support older APIs.
|
||||
type HostTag = config.HostTag
|
||||
|
||||
var c Config
|
||||
if err := json.NewDecoder(f).Decode(&c); err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing %s", filename)
|
||||
}
|
||||
// Host is an alias to support older APIs.
|
||||
type Host = config.Host
|
||||
|
||||
c.init()
|
||||
// SSHPublicKey is an alias to support older APIs.
|
||||
type SSHPublicKey = config.SSHPublicKey
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// initializes the minimal configuration required to create an authority. This
|
||||
// is mainly used on embedded authorities.
|
||||
func (c *Config) init() {
|
||||
if c.DNSNames == nil {
|
||||
c.DNSNames = []string{"localhost", "127.0.0.1", "::1"}
|
||||
}
|
||||
if c.TLS == nil {
|
||||
c.TLS = &DefaultTLSOptions
|
||||
}
|
||||
if c.AuthorityConfig == nil {
|
||||
c.AuthorityConfig = &AuthConfig{}
|
||||
}
|
||||
c.AuthorityConfig.init()
|
||||
}
|
||||
|
||||
// Save saves the configuration to the given filename.
|
||||
func (c *Config) Save(filename string) error {
|
||||
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error opening %s", filename)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", "\t")
|
||||
return errors.Wrapf(enc.Encode(c), "error writing %s", filename)
|
||||
}
|
||||
|
||||
// Validate validates the configuration.
|
||||
func (c *Config) Validate() error {
|
||||
switch {
|
||||
case c.Address == "":
|
||||
return errors.New("address cannot be empty")
|
||||
|
||||
case len(c.DNSNames) == 0:
|
||||
return errors.New("dnsNames cannot be empty")
|
||||
}
|
||||
|
||||
// Options holds the RA/CAS configuration.
|
||||
ra := c.AuthorityConfig.Options
|
||||
// The default RA/CAS requires root, crt and key.
|
||||
if ra.Is(cas.SoftCAS) {
|
||||
switch {
|
||||
case c.Root.HasEmpties():
|
||||
return errors.New("root cannot be empty")
|
||||
case c.IntermediateCert == "":
|
||||
return errors.New("crt cannot be empty")
|
||||
case c.IntermediateKey == "":
|
||||
return errors.New("key cannot be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate address (a port is required)
|
||||
if _, _, err := net.SplitHostPort(c.Address); err != nil {
|
||||
return errors.Errorf("invalid address %s", c.Address)
|
||||
}
|
||||
|
||||
// Validate insecure address if it is configured
|
||||
if c.InsecureAddress != "" {
|
||||
if _, _, err := net.SplitHostPort(c.InsecureAddress); err != nil {
|
||||
return errors.Errorf("invalid address %s", c.InsecureAddress)
|
||||
}
|
||||
}
|
||||
|
||||
if c.TLS == nil {
|
||||
c.TLS = &DefaultTLSOptions
|
||||
} else {
|
||||
if len(c.TLS.CipherSuites) == 0 {
|
||||
c.TLS.CipherSuites = DefaultTLSOptions.CipherSuites
|
||||
}
|
||||
if c.TLS.MaxVersion == 0 {
|
||||
c.TLS.MaxVersion = DefaultTLSOptions.MaxVersion
|
||||
}
|
||||
if c.TLS.MinVersion == 0 {
|
||||
c.TLS.MinVersion = c.TLS.MaxVersion
|
||||
}
|
||||
if c.TLS.MinVersion > c.TLS.MaxVersion {
|
||||
return errors.New("tls minVersion cannot exceed tls maxVersion")
|
||||
}
|
||||
c.TLS.Renegotiation = c.TLS.Renegotiation || DefaultTLSOptions.Renegotiation
|
||||
}
|
||||
|
||||
// Validate KMS options, nil is ok.
|
||||
if err := c.KMS.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate RA/CAS options, nil is ok.
|
||||
if err := ra.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate ssh: nil is ok
|
||||
if err := c.SSH.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate templates: nil is ok
|
||||
if err := c.Templates.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.AuthorityConfig.Validate(c.getAudiences())
|
||||
}
|
||||
|
||||
// getAudiences returns the legacy and possible urls without the ports that will
|
||||
// be used as the default provisioner audiences. The CA might have proxies in
|
||||
// front so we cannot rely on the port.
|
||||
func (c *Config) getAudiences() provisioner.Audiences {
|
||||
audiences := provisioner.Audiences{
|
||||
Sign: []string{legacyAuthority},
|
||||
Revoke: []string{legacyAuthority},
|
||||
SSHSign: []string{},
|
||||
SSHRevoke: []string{},
|
||||
SSHRenew: []string{},
|
||||
}
|
||||
|
||||
for _, name := range c.DNSNames {
|
||||
audiences.Sign = append(audiences.Sign,
|
||||
fmt.Sprintf("https://%s/1.0/sign", name),
|
||||
fmt.Sprintf("https://%s/sign", name),
|
||||
fmt.Sprintf("https://%s/1.0/ssh/sign", name),
|
||||
fmt.Sprintf("https://%s/ssh/sign", name))
|
||||
audiences.Revoke = append(audiences.Revoke,
|
||||
fmt.Sprintf("https://%s/1.0/revoke", name),
|
||||
fmt.Sprintf("https://%s/revoke", name))
|
||||
audiences.SSHSign = append(audiences.SSHSign,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/sign", name),
|
||||
fmt.Sprintf("https://%s/ssh/sign", name),
|
||||
fmt.Sprintf("https://%s/1.0/sign", name),
|
||||
fmt.Sprintf("https://%s/sign", name))
|
||||
audiences.SSHRevoke = append(audiences.SSHRevoke,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/revoke", name),
|
||||
fmt.Sprintf("https://%s/ssh/revoke", name))
|
||||
audiences.SSHRenew = append(audiences.SSHRenew,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/renew", name),
|
||||
fmt.Sprintf("https://%s/ssh/renew", name))
|
||||
audiences.SSHRekey = append(audiences.SSHRekey,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/rekey", name),
|
||||
fmt.Sprintf("https://%s/ssh/rekey", name))
|
||||
}
|
||||
|
||||
return audiences
|
||||
}
|
||||
// SSHKeys is an alias to support older APIs.
|
||||
type SSHKeys = config.SSHKeys
|
||||
|
@ -0,0 +1,297 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
cas "github.com/smallstep/certificates/cas/apiv1"
|
||||
"github.com/smallstep/certificates/db"
|
||||
kms "github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
const (
|
||||
legacyAuthority = "step-certificate-authority"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultBackdate length of time to backdate certificates to avoid
|
||||
// clock skew validation issues.
|
||||
DefaultBackdate = time.Minute
|
||||
// DefaultDisableRenewal disables renewals per provisioner.
|
||||
DefaultDisableRenewal = false
|
||||
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally
|
||||
// for all provisioners.
|
||||
DefaultEnableSSHCA = false
|
||||
// GlobalProvisionerClaims default claims for the Authority. Can be overridden
|
||||
// by provisioner specific claims.
|
||||
GlobalProvisionerClaims = provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &DefaultDisableRenewal,
|
||||
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
|
||||
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
EnableSSHCA: &DefaultEnableSSHCA,
|
||||
}
|
||||
)
|
||||
|
||||
// Config represents the CA configuration and it's mapped to a JSON object.
|
||||
type Config struct {
|
||||
Root multiString `json:"root"`
|
||||
FederatedRoots []string `json:"federatedRoots"`
|
||||
IntermediateCert string `json:"crt"`
|
||||
IntermediateKey string `json:"key"`
|
||||
Address string `json:"address"`
|
||||
InsecureAddress string `json:"insecureAddress"`
|
||||
DNSNames []string `json:"dnsNames"`
|
||||
KMS *kms.Options `json:"kms,omitempty"`
|
||||
SSH *SSHConfig `json:"ssh,omitempty"`
|
||||
Logger json.RawMessage `json:"logger,omitempty"`
|
||||
DB *db.Config `json:"db,omitempty"`
|
||||
Monitoring json.RawMessage `json:"monitoring,omitempty"`
|
||||
AuthorityConfig *AuthConfig `json:"authority,omitempty"`
|
||||
TLS *TLSOptions `json:"tls,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Templates *templates.Templates `json:"templates,omitempty"`
|
||||
}
|
||||
|
||||
// ASN1DN contains ASN1.DN attributes that are used in Subject and Issuer
|
||||
// x509 Certificate blocks.
|
||||
type ASN1DN struct {
|
||||
Country string `json:"country,omitempty"`
|
||||
Organization string `json:"organization,omitempty"`
|
||||
OrganizationalUnit string `json:"organizationalUnit,omitempty"`
|
||||
Locality string `json:"locality,omitempty"`
|
||||
Province string `json:"province,omitempty"`
|
||||
StreetAddress string `json:"streetAddress,omitempty"`
|
||||
SerialNumber string `json:"serialNumber,omitempty"`
|
||||
CommonName string `json:"commonName,omitempty"`
|
||||
}
|
||||
|
||||
// AuthConfig represents the configuration options for the authority. An
|
||||
// underlaying registration authority can also be configured using the
|
||||
// cas.Options.
|
||||
type AuthConfig struct {
|
||||
*cas.Options
|
||||
AuthorityID string `json:"authorityId,omitempty"`
|
||||
DeploymentType string `json:"deploymentType,omitempty"`
|
||||
Provisioners provisioner.List `json:"provisioners,omitempty"`
|
||||
Admins []*linkedca.Admin `json:"-"`
|
||||
Template *ASN1DN `json:"template,omitempty"`
|
||||
Claims *provisioner.Claims `json:"claims,omitempty"`
|
||||
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
|
||||
Backdate *provisioner.Duration `json:"backdate,omitempty"`
|
||||
EnableAdmin bool `json:"enableAdmin,omitempty"`
|
||||
}
|
||||
|
||||
// init initializes the required fields in the AuthConfig if they are not
|
||||
// provided.
|
||||
func (c *AuthConfig) init() {
|
||||
if c.Provisioners == nil {
|
||||
c.Provisioners = provisioner.List{}
|
||||
}
|
||||
if c.Template == nil {
|
||||
c.Template = &ASN1DN{}
|
||||
}
|
||||
if c.Backdate == nil {
|
||||
c.Backdate = &provisioner.Duration{
|
||||
Duration: DefaultBackdate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the authority configuration.
|
||||
func (c *AuthConfig) Validate(audiences provisioner.Audiences) error {
|
||||
if c == nil {
|
||||
return errors.New("authority cannot be undefined")
|
||||
}
|
||||
|
||||
// Initialize required fields.
|
||||
c.init()
|
||||
|
||||
// Check that only one K8sSA is enabled
|
||||
var k8sCount int
|
||||
for _, p := range c.Provisioners {
|
||||
if p.GetType() == provisioner.TypeK8sSA {
|
||||
k8sCount++
|
||||
}
|
||||
}
|
||||
if k8sCount > 1 {
|
||||
return errors.New("cannot have more than one kubernetes service account provisioner")
|
||||
}
|
||||
|
||||
if c.Backdate.Duration < 0 {
|
||||
return errors.New("authority.backdate cannot be less than 0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadConfiguration parses the given filename in JSON format and returns the
|
||||
// configuration struct.
|
||||
func LoadConfiguration(filename string) (*Config, error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error opening %s", filename)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var c Config
|
||||
if err := json.NewDecoder(f).Decode(&c); err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing %s", filename)
|
||||
}
|
||||
|
||||
c.Init()
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// Init initializes the minimal configuration required to create an authority. This
|
||||
// is mainly used on embedded authorities.
|
||||
func (c *Config) Init() {
|
||||
if c.DNSNames == nil {
|
||||
c.DNSNames = []string{"localhost", "127.0.0.1", "::1"}
|
||||
}
|
||||
if c.TLS == nil {
|
||||
c.TLS = &DefaultTLSOptions
|
||||
}
|
||||
if c.AuthorityConfig == nil {
|
||||
c.AuthorityConfig = &AuthConfig{}
|
||||
}
|
||||
c.AuthorityConfig.init()
|
||||
}
|
||||
|
||||
// Save saves the configuration to the given filename.
|
||||
func (c *Config) Save(filename string) error {
|
||||
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error opening %s", filename)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", "\t")
|
||||
return errors.Wrapf(enc.Encode(c), "error writing %s", filename)
|
||||
}
|
||||
|
||||
// Validate validates the configuration.
|
||||
func (c *Config) Validate() error {
|
||||
switch {
|
||||
case c.Address == "":
|
||||
return errors.New("address cannot be empty")
|
||||
case len(c.DNSNames) == 0:
|
||||
return errors.New("dnsNames cannot be empty")
|
||||
case c.AuthorityConfig == nil:
|
||||
return errors.New("authority cannot be nil")
|
||||
}
|
||||
|
||||
// Options holds the RA/CAS configuration.
|
||||
ra := c.AuthorityConfig.Options
|
||||
// The default RA/CAS requires root, crt and key.
|
||||
if ra.Is(cas.SoftCAS) {
|
||||
switch {
|
||||
case c.Root.HasEmpties():
|
||||
return errors.New("root cannot be empty")
|
||||
case c.IntermediateCert == "":
|
||||
return errors.New("crt cannot be empty")
|
||||
case c.IntermediateKey == "":
|
||||
return errors.New("key cannot be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate address (a port is required)
|
||||
if _, _, err := net.SplitHostPort(c.Address); err != nil {
|
||||
return errors.Errorf("invalid address %s", c.Address)
|
||||
}
|
||||
|
||||
if c.TLS == nil {
|
||||
c.TLS = &DefaultTLSOptions
|
||||
} else {
|
||||
if len(c.TLS.CipherSuites) == 0 {
|
||||
c.TLS.CipherSuites = DefaultTLSOptions.CipherSuites
|
||||
}
|
||||
if c.TLS.MaxVersion == 0 {
|
||||
c.TLS.MaxVersion = DefaultTLSOptions.MaxVersion
|
||||
}
|
||||
if c.TLS.MinVersion == 0 {
|
||||
c.TLS.MinVersion = DefaultTLSOptions.MinVersion
|
||||
}
|
||||
if c.TLS.MinVersion > c.TLS.MaxVersion {
|
||||
return errors.New("tls minVersion cannot exceed tls maxVersion")
|
||||
}
|
||||
c.TLS.Renegotiation = c.TLS.Renegotiation || DefaultTLSOptions.Renegotiation
|
||||
}
|
||||
|
||||
// Validate KMS options, nil is ok.
|
||||
if err := c.KMS.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate RA/CAS options, nil is ok.
|
||||
if err := ra.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate ssh: nil is ok
|
||||
if err := c.SSH.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate templates: nil is ok
|
||||
if err := c.Templates.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.AuthorityConfig.Validate(c.GetAudiences())
|
||||
}
|
||||
|
||||
// GetAudiences returns the legacy and possible urls without the ports that will
|
||||
// be used as the default provisioner audiences. The CA might have proxies in
|
||||
// front so we cannot rely on the port.
|
||||
func (c *Config) GetAudiences() provisioner.Audiences {
|
||||
audiences := provisioner.Audiences{
|
||||
Sign: []string{legacyAuthority},
|
||||
Revoke: []string{legacyAuthority},
|
||||
SSHSign: []string{},
|
||||
SSHRevoke: []string{},
|
||||
SSHRenew: []string{},
|
||||
}
|
||||
|
||||
for _, name := range c.DNSNames {
|
||||
audiences.Sign = append(audiences.Sign,
|
||||
fmt.Sprintf("https://%s/1.0/sign", name),
|
||||
fmt.Sprintf("https://%s/sign", name),
|
||||
fmt.Sprintf("https://%s/1.0/ssh/sign", name),
|
||||
fmt.Sprintf("https://%s/ssh/sign", name))
|
||||
audiences.Revoke = append(audiences.Revoke,
|
||||
fmt.Sprintf("https://%s/1.0/revoke", name),
|
||||
fmt.Sprintf("https://%s/revoke", name))
|
||||
audiences.SSHSign = append(audiences.SSHSign,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/sign", name),
|
||||
fmt.Sprintf("https://%s/ssh/sign", name),
|
||||
fmt.Sprintf("https://%s/1.0/sign", name),
|
||||
fmt.Sprintf("https://%s/sign", name))
|
||||
audiences.SSHRevoke = append(audiences.SSHRevoke,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/revoke", name),
|
||||
fmt.Sprintf("https://%s/ssh/revoke", name))
|
||||
audiences.SSHRenew = append(audiences.SSHRenew,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/renew", name),
|
||||
fmt.Sprintf("https://%s/ssh/renew", name))
|
||||
audiences.SSHRekey = append(audiences.SSHRekey,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/rekey", name),
|
||||
fmt.Sprintf("https://%s/ssh/rekey", name))
|
||||
}
|
||||
|
||||
return audiences
|
||||
}
|
@ -0,0 +1,94 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/jose"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// SSHConfig contains the user and host keys.
|
||||
type SSHConfig struct {
|
||||
HostKey string `json:"hostKey"`
|
||||
UserKey string `json:"userKey"`
|
||||
Keys []*SSHPublicKey `json:"keys,omitempty"`
|
||||
AddUserPrincipal string `json:"addUserPrincipal,omitempty"`
|
||||
AddUserCommand string `json:"addUserCommand,omitempty"`
|
||||
Bastion *Bastion `json:"bastion,omitempty"`
|
||||
}
|
||||
|
||||
// Bastion contains the custom properties used on bastion.
|
||||
type Bastion struct {
|
||||
Hostname string `json:"hostname"`
|
||||
User string `json:"user,omitempty"`
|
||||
Port string `json:"port,omitempty"`
|
||||
Command string `json:"cmd,omitempty"`
|
||||
Flags string `json:"flags,omitempty"`
|
||||
}
|
||||
|
||||
// HostTag are tagged with k,v pairs. These tags are how a user is ultimately
|
||||
// associated with a host.
|
||||
type HostTag struct {
|
||||
ID string
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
// Host defines expected attributes for an ssh host.
|
||||
type Host struct {
|
||||
HostID string `json:"hid"`
|
||||
HostTags []HostTag `json:"host_tags"`
|
||||
Hostname string `json:"hostname"`
|
||||
}
|
||||
|
||||
// Validate checks the fields in SSHConfig.
|
||||
func (c *SSHConfig) Validate() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
for _, k := range c.Keys {
|
||||
if err := k.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SSHPublicKey contains a public key used by federated CAs to keep old signing
|
||||
// keys for this ca.
|
||||
type SSHPublicKey struct {
|
||||
Type string `json:"type"`
|
||||
Federated bool `json:"federated"`
|
||||
Key jose.JSONWebKey `json:"key"`
|
||||
publicKey ssh.PublicKey
|
||||
}
|
||||
|
||||
// Validate checks the fields in SSHPublicKey.
|
||||
func (k *SSHPublicKey) Validate() error {
|
||||
switch {
|
||||
case k.Type == "":
|
||||
return errors.New("type cannot be empty")
|
||||
case k.Type != provisioner.SSHHostCert && k.Type != provisioner.SSHUserCert:
|
||||
return errors.Errorf("invalid type %s, it must be user or host", k.Type)
|
||||
case !k.Key.IsPublic():
|
||||
return errors.New("invalid key type, it must be a public key")
|
||||
}
|
||||
|
||||
key, err := ssh.NewPublicKey(k.Key.Key)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error creating ssh key")
|
||||
}
|
||||
k.publicKey = key
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublicKey returns the ssh public key.
|
||||
func (k *SSHPublicKey) PublicKey() ssh.PublicKey {
|
||||
return k.publicKey
|
||||
}
|
||||
|
||||
// SSHKeys represents the SSH User and Host public keys.
|
||||
type SSHKeys struct {
|
||||
UserKeys []ssh.PublicKey
|
||||
HostKeys []ssh.PublicKey
|
||||
}
|
@ -0,0 +1,73 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"go.step.sm/crypto/jose"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestSSHPublicKey_Validate(t *testing.T) {
|
||||
key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type fields struct {
|
||||
Type string
|
||||
Federated bool
|
||||
Key jose.JSONWebKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{"user", fields{"user", true, key.Public()}, false},
|
||||
{"host", fields{"host", false, key.Public()}, false},
|
||||
{"empty", fields{"", true, key.Public()}, true},
|
||||
{"badType", fields{"bad", false, key.Public()}, true},
|
||||
{"badKey", fields{"user", false, *key}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &SSHPublicKey{
|
||||
Type: tt.fields.Type,
|
||||
Federated: tt.fields.Federated,
|
||||
Key: tt.fields.Key,
|
||||
}
|
||||
if err := k.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("SSHPublicKey.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHPublicKey_PublicKey(t *testing.T) {
|
||||
key, err := jose.GenerateJWK("EC", "P-256", "", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
pub, err := ssh.NewPublicKey(key.Public().Key)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type fields struct {
|
||||
publicKey ssh.PublicKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want ssh.PublicKey
|
||||
}{
|
||||
{"ok", fields{pub}, pub},
|
||||
{"nil", fields{nil}, nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &SSHPublicKey{
|
||||
publicKey: tt.fields.publicKey,
|
||||
}
|
||||
if got := k.PublicKey(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SSHPublicKey.PublicKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package authority
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
@ -1,4 +1,4 @@
|
||||
package authority
|
||||
package config
|
||||
|
||||
import (
|
||||
"reflect"
|
@ -0,0 +1,284 @@
|
||||
package authority
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/cli-utils/config"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
// Export creates a linkedca configuration form the current ca.json and loaded
|
||||
// authorities.
|
||||
//
|
||||
// Note that export will not export neither the pki password nor the certificate
|
||||
// issuer password.
|
||||
func (a *Authority) Export() (c *linkedca.Configuration, err error) {
|
||||
// Recover from panics
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
|
||||
files := make(map[string][]byte)
|
||||
|
||||
// The exported configuration should not include the password in it.
|
||||
c = &linkedca.Configuration{
|
||||
Version: "1.0",
|
||||
Root: mustReadFilesOrURIs(a.config.Root, files),
|
||||
FederatedRoots: mustReadFilesOrURIs(a.config.FederatedRoots, files),
|
||||
Intermediate: mustReadFileOrURI(a.config.IntermediateCert, files),
|
||||
IntermediateKey: mustReadFileOrURI(a.config.IntermediateKey, files),
|
||||
Address: a.config.Address,
|
||||
InsecureAddress: a.config.InsecureAddress,
|
||||
DnsNames: a.config.DNSNames,
|
||||
Db: mustMarshalToStruct(a.config.DB),
|
||||
Logger: mustMarshalToStruct(a.config.Logger),
|
||||
Monitoring: mustMarshalToStruct(a.config.Monitoring),
|
||||
Authority: &linkedca.Authority{
|
||||
Id: a.config.AuthorityConfig.AuthorityID,
|
||||
EnableAdmin: a.config.AuthorityConfig.EnableAdmin,
|
||||
DisableIssuedAtCheck: a.config.AuthorityConfig.DisableIssuedAtCheck,
|
||||
Backdate: mustDuration(a.config.AuthorityConfig.Backdate),
|
||||
DeploymentType: a.config.AuthorityConfig.DeploymentType,
|
||||
},
|
||||
Files: files,
|
||||
}
|
||||
|
||||
// SSH
|
||||
if v := a.config.SSH; v != nil {
|
||||
c.Ssh = &linkedca.SSH{
|
||||
HostKey: mustReadFileOrURI(v.HostKey, files),
|
||||
UserKey: mustReadFileOrURI(v.UserKey, files),
|
||||
AddUserPrincipal: v.AddUserPrincipal,
|
||||
AddUserCommand: v.AddUserCommand,
|
||||
}
|
||||
for _, k := range v.Keys {
|
||||
typ, ok := linkedca.SSHPublicKey_Type_value[strings.ToUpper(k.Type)]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unsupported ssh key type %s", k.Type)
|
||||
}
|
||||
c.Ssh.Keys = append(c.Ssh.Keys, &linkedca.SSHPublicKey{
|
||||
Type: linkedca.SSHPublicKey_Type(typ),
|
||||
Federated: k.Federated,
|
||||
Key: mustMarshalToStruct(k),
|
||||
})
|
||||
}
|
||||
if b := v.Bastion; b != nil {
|
||||
c.Ssh.Bastion = &linkedca.Bastion{
|
||||
Hostname: b.Hostname,
|
||||
User: b.User,
|
||||
Port: b.Port,
|
||||
Command: b.Command,
|
||||
Flags: b.Flags,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// KMS
|
||||
if v := a.config.KMS; v != nil {
|
||||
var typ int32
|
||||
var ok bool
|
||||
if v.Type == "" {
|
||||
typ = int32(linkedca.KMS_SOFTKMS)
|
||||
} else {
|
||||
typ, ok = linkedca.KMS_Type_value[strings.ToUpper(v.Type)]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unsupported kms type %s", v.Type)
|
||||
}
|
||||
}
|
||||
c.Kms = &linkedca.KMS{
|
||||
Type: linkedca.KMS_Type(typ),
|
||||
CredentialsFile: v.CredentialsFile,
|
||||
Uri: v.URI,
|
||||
Pin: v.Pin,
|
||||
ManagementKey: v.ManagementKey,
|
||||
Region: v.Region,
|
||||
Profile: v.Profile,
|
||||
}
|
||||
}
|
||||
|
||||
// Authority
|
||||
// cas options
|
||||
if v := a.config.AuthorityConfig.Options; v != nil {
|
||||
c.Authority.Type = 0
|
||||
c.Authority.CertificateAuthority = v.CertificateAuthority
|
||||
c.Authority.CertificateAuthorityFingerprint = v.CertificateAuthorityFingerprint
|
||||
c.Authority.CredentialsFile = v.CredentialsFile
|
||||
if iss := v.CertificateIssuer; iss != nil {
|
||||
typ, ok := linkedca.CertificateIssuer_Type_value[strings.ToUpper(iss.Type)]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unknown certificate issuer type %s", iss.Type)
|
||||
}
|
||||
// The exported certificate issuer should not include the password.
|
||||
c.Authority.CertificateIssuer = &linkedca.CertificateIssuer{
|
||||
Type: linkedca.CertificateIssuer_Type(typ),
|
||||
Provisioner: iss.Provisioner,
|
||||
Certificate: mustReadFileOrURI(iss.Certificate, files),
|
||||
Key: mustReadFileOrURI(iss.Key, files),
|
||||
}
|
||||
}
|
||||
}
|
||||
// admins
|
||||
for {
|
||||
list, cursor := a.admins.Find("", 100)
|
||||
c.Authority.Admins = append(c.Authority.Admins, list...)
|
||||
if cursor == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
// provisioners
|
||||
for {
|
||||
list, cursor := a.provisioners.Find("", 100)
|
||||
for _, p := range list {
|
||||
lp, err := ProvisionerToLinkedca(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Authority.Provisioners = append(c.Authority.Provisioners, lp)
|
||||
}
|
||||
if cursor == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
// global claims
|
||||
c.Authority.Claims = claimsToLinkedca(a.config.AuthorityConfig.Claims)
|
||||
// Distinguished names template
|
||||
if v := a.config.AuthorityConfig.Template; v != nil {
|
||||
c.Authority.Template = &linkedca.DistinguishedName{
|
||||
Country: v.Country,
|
||||
Organization: v.Organization,
|
||||
OrganizationalUnit: v.OrganizationalUnit,
|
||||
Locality: v.Locality,
|
||||
Province: v.Province,
|
||||
StreetAddress: v.StreetAddress,
|
||||
SerialNumber: v.SerialNumber,
|
||||
CommonName: v.CommonName,
|
||||
}
|
||||
}
|
||||
|
||||
// TLS
|
||||
if v := a.config.TLS; v != nil {
|
||||
c.Tls = &linkedca.TLS{
|
||||
MinVersion: v.MinVersion.String(),
|
||||
MaxVersion: v.MaxVersion.String(),
|
||||
Renegotiation: v.Renegotiation,
|
||||
}
|
||||
for _, cs := range v.CipherSuites.Value() {
|
||||
c.Tls.CipherSuites = append(c.Tls.CipherSuites, linkedca.TLS_CiperSuite(cs))
|
||||
}
|
||||
}
|
||||
|
||||
// Templates
|
||||
if v := a.config.Templates; v != nil {
|
||||
c.Templates = &linkedca.ConfigTemplates{
|
||||
Ssh: &linkedca.SSHConfigTemplate{},
|
||||
Data: mustMarshalToStruct(v.Data),
|
||||
}
|
||||
// Remove automatically loaded vars
|
||||
if c.Templates.Data != nil && c.Templates.Data.Fields != nil {
|
||||
delete(c.Templates.Data.Fields, "Step")
|
||||
}
|
||||
for _, t := range v.SSH.Host {
|
||||
typ, ok := linkedca.ConfigTemplate_Type_value[strings.ToUpper(string(t.Type))]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unsupported template type %s", t.Type)
|
||||
}
|
||||
c.Templates.Ssh.Hosts = append(c.Templates.Ssh.Hosts, &linkedca.ConfigTemplate{
|
||||
Type: linkedca.ConfigTemplate_Type(typ),
|
||||
Name: t.Name,
|
||||
Template: mustReadFileOrURI(t.TemplatePath, files),
|
||||
Path: t.Path,
|
||||
Comment: t.Comment,
|
||||
Requires: t.RequiredData,
|
||||
Content: t.Content,
|
||||
})
|
||||
}
|
||||
for _, t := range v.SSH.User {
|
||||
typ, ok := linkedca.ConfigTemplate_Type_value[strings.ToUpper(string(t.Type))]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unsupported template type %s", t.Type)
|
||||
}
|
||||
c.Templates.Ssh.Users = append(c.Templates.Ssh.Users, &linkedca.ConfigTemplate{
|
||||
Type: linkedca.ConfigTemplate_Type(typ),
|
||||
Name: t.Name,
|
||||
Template: mustReadFileOrURI(t.TemplatePath, files),
|
||||
Path: t.Path,
|
||||
Comment: t.Comment,
|
||||
Requires: t.RequiredData,
|
||||
Content: t.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func mustDuration(d *provisioner.Duration) string {
|
||||
if d == nil || d.Duration == 0 {
|
||||
return ""
|
||||
}
|
||||
return d.String()
|
||||
}
|
||||
|
||||
func mustMarshalToStruct(v interface{}) *structpb.Struct {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
panic(errors.Wrapf(err, "error marshaling %T", v))
|
||||
}
|
||||
var r *structpb.Struct
|
||||
if err := json.Unmarshal(b, &r); err != nil {
|
||||
panic(errors.Wrapf(err, "error unmarshaling %T", v))
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func mustReadFileOrURI(fn string, m map[string][]byte) string {
|
||||
if fn == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
stepPath := filepath.ToSlash(config.StepPath())
|
||||
if !strings.HasSuffix(stepPath, "/") {
|
||||
stepPath += "/"
|
||||
}
|
||||
|
||||
fn = strings.TrimPrefix(filepath.ToSlash(fn), stepPath)
|
||||
|
||||
ok, err := isFilename(fn)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if ok {
|
||||
b, err := ioutil.ReadFile(config.StepAbs(fn))
|
||||
if err != nil {
|
||||
panic(errors.Wrapf(err, "error reading %s", fn))
|
||||
}
|
||||
m[fn] = b
|
||||
return fn
|
||||
}
|
||||
return fn
|
||||
}
|
||||
|
||||
func mustReadFilesOrURIs(fns []string, m map[string][]byte) []string {
|
||||
var result []string
|
||||
for _, fn := range fns {
|
||||
result = append(result, mustReadFileOrURI(fn, m))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func isFilename(fn string) (bool, error) {
|
||||
u, err := url.Parse(fn)
|
||||
if err != nil {
|
||||
return false, errors.Wrapf(err, "error parsing %s", fn)
|
||||
}
|
||||
return u.Scheme == "" || u.Scheme == "file", nil
|
||||
}
|
@ -0,0 +1,490 @@
|
||||
package authority
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/tlsutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
"go.step.sm/linkedca"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
const uuidPattern = "^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$"
|
||||
|
||||
type linkedCaClient struct {
|
||||
renewer *tlsutil.Renewer
|
||||
client linkedca.MajordomoClient
|
||||
authorityID string
|
||||
}
|
||||
|
||||
type linkedCAClaims struct {
|
||||
jose.Claims
|
||||
SANs []string `json:"sans"`
|
||||
SHA string `json:"sha"`
|
||||
}
|
||||
|
||||
func newLinkedCAClient(token string) (*linkedCaClient, error) {
|
||||
tok, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing token")
|
||||
}
|
||||
|
||||
var claims linkedCAClaims
|
||||
if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing token")
|
||||
}
|
||||
// Validate claims
|
||||
if len(claims.Audience) != 1 {
|
||||
return nil, errors.New("error parsing token: invalid aud claim")
|
||||
}
|
||||
if claims.SHA == "" {
|
||||
return nil, errors.New("error parsing token: invalid sha claim")
|
||||
}
|
||||
// Get linkedCA endpoint from audience.
|
||||
u, err := url.Parse(claims.Audience[0])
|
||||
if err != nil {
|
||||
return nil, errors.New("error parsing token: invalid aud claim")
|
||||
}
|
||||
// Get authority from SANs
|
||||
authority, err := getAuthority(claims.SANs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create csr to login with
|
||||
signer, err := keyutil.GenerateDefaultSigner()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
csr, err := x509util.CreateCertificateRequest(claims.Subject, claims.SANs, signer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get and verify root certificate
|
||||
root, err := getRootCertificate(u.Host, claims.SHA)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(root)
|
||||
|
||||
// Login with majordomo and get certificates
|
||||
cert, tlsConfig, err := login(authority, token, csr, signer, u.Host, pool)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start TLS renewer and set the GetClientCertificate callback to it.
|
||||
renewer, err := tlsutil.NewRenewer(cert, tlsConfig, func() (*tls.Certificate, *tls.Config, error) {
|
||||
return login(authority, token, csr, signer, u.Host, pool)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||
|
||||
// Start mTLS client
|
||||
conn, err := grpc.Dial(u.Host, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error connecting %s", u.Host)
|
||||
}
|
||||
|
||||
return &linkedCaClient{
|
||||
renewer: renewer,
|
||||
client: linkedca.NewMajordomoClient(conn),
|
||||
authorityID: authority,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) Run() {
|
||||
c.renewer.Run()
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) Stop() {
|
||||
c.renewer.Stop()
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
resp, err := c.client.CreateProvisioner(ctx, &linkedca.CreateProvisionerRequest{
|
||||
Type: prov.Type,
|
||||
Name: prov.Name,
|
||||
Details: prov.Details,
|
||||
Claims: prov.Claims,
|
||||
X509Template: prov.X509Template,
|
||||
SshTemplate: prov.SshTemplate,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error creating provisioner")
|
||||
}
|
||||
prov.Id = resp.Id
|
||||
prov.AuthorityId = resp.AuthorityId
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
resp, err := c.client.GetProvisioner(ctx, &linkedca.GetProvisionerRequest{
|
||||
Id: id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error getting provisioners")
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
|
||||
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
|
||||
AuthorityId: c.authorityID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error getting provisioners")
|
||||
}
|
||||
return resp.Provisioners, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
_, err := c.client.UpdateProvisioner(ctx, &linkedca.UpdateProvisionerRequest{
|
||||
Id: prov.Id,
|
||||
Name: prov.Name,
|
||||
Details: prov.Details,
|
||||
Claims: prov.Claims,
|
||||
X509Template: prov.X509Template,
|
||||
SshTemplate: prov.SshTemplate,
|
||||
})
|
||||
return errors.Wrap(err, "error updating provisioner")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) DeleteProvisioner(ctx context.Context, id string) error {
|
||||
_, err := c.client.DeleteProvisioner(ctx, &linkedca.DeleteProvisionerRequest{
|
||||
Id: id,
|
||||
})
|
||||
return errors.Wrap(err, "error deleting provisioner")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) CreateAdmin(ctx context.Context, adm *linkedca.Admin) error {
|
||||
resp, err := c.client.CreateAdmin(ctx, &linkedca.CreateAdminRequest{
|
||||
Subject: adm.Subject,
|
||||
ProvisionerId: adm.ProvisionerId,
|
||||
Type: adm.Type,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error creating admin")
|
||||
}
|
||||
adm.Id = resp.Id
|
||||
adm.AuthorityId = resp.AuthorityId
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) {
|
||||
resp, err := c.client.GetAdmin(ctx, &linkedca.GetAdminRequest{
|
||||
Id: id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error getting admins")
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
|
||||
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
|
||||
AuthorityId: c.authorityID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error getting admins")
|
||||
}
|
||||
return resp.Admins, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error {
|
||||
_, err := c.client.UpdateAdmin(ctx, &linkedca.UpdateAdminRequest{
|
||||
Id: adm.Id,
|
||||
Type: adm.Type,
|
||||
})
|
||||
return errors.Wrap(err, "error updating admin")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error {
|
||||
_, err := c.client.DeleteAdmin(ctx, &linkedca.DeleteAdminRequest{
|
||||
Id: id,
|
||||
})
|
||||
return errors.Wrap(err, "error deleting admin")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) StoreCertificateChain(fullchain ...*x509.Certificate) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
|
||||
PemCertificate: serializeCertificateChain(fullchain[0]),
|
||||
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
|
||||
})
|
||||
return errors.Wrap(err, "error posting certificate")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) StoreRenewedCertificate(parent *x509.Certificate, fullchain ...*x509.Certificate) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
|
||||
PemCertificate: serializeCertificateChain(fullchain[0]),
|
||||
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
|
||||
PemParentCertificate: serializeCertificateChain(parent),
|
||||
})
|
||||
return errors.Wrap(err, "error posting certificate")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) StoreSSHCertificate(crt *ssh.Certificate) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{
|
||||
Certificate: string(ssh.MarshalAuthorizedKey(crt)),
|
||||
})
|
||||
return errors.Wrap(err, "error posting ssh certificate")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.client.RevokeCertificate(ctx, &linkedca.RevokeCertificateRequest{
|
||||
Serial: rci.Serial,
|
||||
PemCertificate: serializeCertificate(crt),
|
||||
Reason: rci.Reason,
|
||||
ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode),
|
||||
Passive: true,
|
||||
})
|
||||
|
||||
return errors.Wrap(err, "error revoking certificate")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) RevokeSSH(cert *ssh.Certificate, rci *db.RevokedCertificateInfo) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.client.RevokeSSHCertificate(ctx, &linkedca.RevokeSSHCertificateRequest{
|
||||
Serial: rci.Serial,
|
||||
Certificate: serializeSSHCertificate(cert),
|
||||
Reason: rci.Reason,
|
||||
ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode),
|
||||
Passive: true,
|
||||
})
|
||||
|
||||
return errors.Wrap(err, "error revoking ssh certificate")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) IsRevoked(serial string) (bool, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
resp, err := c.client.GetCertificateStatus(ctx, &linkedca.GetCertificateStatusRequest{
|
||||
Serial: serial,
|
||||
})
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "error getting certificate status")
|
||||
}
|
||||
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
resp, err := c.client.GetSSHCertificateStatus(ctx, &linkedca.GetSSHCertificateStatusRequest{
|
||||
Serial: serial,
|
||||
})
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "error getting certificate status")
|
||||
}
|
||||
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
|
||||
}
|
||||
|
||||
func serializeCertificate(crt *x509.Certificate) string {
|
||||
if crt == nil {
|
||||
return ""
|
||||
}
|
||||
return string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: crt.Raw,
|
||||
}))
|
||||
}
|
||||
|
||||
func serializeCertificateChain(fullchain ...*x509.Certificate) string {
|
||||
var chain string
|
||||
for _, crt := range fullchain {
|
||||
chain += string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: crt.Raw,
|
||||
}))
|
||||
}
|
||||
return chain
|
||||
}
|
||||
|
||||
func serializeSSHCertificate(crt *ssh.Certificate) string {
|
||||
if crt == nil {
|
||||
return ""
|
||||
}
|
||||
return string(ssh.MarshalAuthorizedKey(crt))
|
||||
}
|
||||
|
||||
func getAuthority(sans []string) (string, error) {
|
||||
for _, s := range sans {
|
||||
if strings.HasPrefix(s, "urn:smallstep:authority:") {
|
||||
if regexp.MustCompile(uuidPattern).MatchString(s[24:]) {
|
||||
return s[24:], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("error parsing token: invalid sans claim")
|
||||
}
|
||||
|
||||
// getRootCertificate creates an insecure majordomo client and returns the
|
||||
// verified root certificate.
|
||||
func getRootCertificate(endpoint, fingerprint string) (*x509.Certificate, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(ctx, endpoint, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})))
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error connecting %s", endpoint)
|
||||
}
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client := linkedca.NewMajordomoClient(conn)
|
||||
resp, err := client.GetRootCertificate(ctx, &linkedca.GetRootCertificateRequest{
|
||||
Fingerprint: fingerprint,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting root certificate: %w", err)
|
||||
}
|
||||
|
||||
var block *pem.Block
|
||||
b := []byte(resp.PemCertificate)
|
||||
for len(b) > 0 {
|
||||
block, b = pem.Decode(b)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %w", err)
|
||||
}
|
||||
|
||||
// verify the sha256
|
||||
sum := sha256.Sum256(cert.Raw)
|
||||
if !strings.EqualFold(fingerprint, hex.EncodeToString(sum[:])) {
|
||||
return nil, fmt.Errorf("error verifying certificate: SHA256 fingerprint does not match")
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("error getting root certificate: certificate not found")
|
||||
}
|
||||
|
||||
// login creates a new majordomo client with just the root ca pool and returns
|
||||
// the signed certificate and tls configuration.
|
||||
func login(authority, token string, csr *x509.CertificateRequest, signer crypto.PrivateKey, endpoint string, rootCAs *x509.CertPool) (*tls.Certificate, *tls.Config, error) {
|
||||
// Connect to majordomo
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(ctx, endpoint, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
})))
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrapf(err, "error connecting %s", endpoint)
|
||||
}
|
||||
|
||||
// Login to get the signed certificate
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client := linkedca.NewMajordomoClient(conn)
|
||||
resp, err := client.Login(ctx, &linkedca.LoginRequest{
|
||||
AuthorityId: authority,
|
||||
Token: token,
|
||||
PemCertificateRequest: string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE REQUEST",
|
||||
Bytes: csr.Raw,
|
||||
})),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrapf(err, "error logging in %s", endpoint)
|
||||
}
|
||||
|
||||
// Parse login response
|
||||
var block *pem.Block
|
||||
var bundle []*x509.Certificate
|
||||
rest := []byte(resp.PemCertificateChain)
|
||||
for {
|
||||
block, rest = pem.Decode(rest)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type != "CERTIFICATE" {
|
||||
return nil, nil, errors.New("error decoding login response: pemCertificateChain is not a certificate bundle")
|
||||
}
|
||||
crt, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "error parsing login response")
|
||||
}
|
||||
bundle = append(bundle, crt)
|
||||
}
|
||||
if len(bundle) == 0 {
|
||||
return nil, nil, errors.New("error decoding login response: pemCertificateChain should not be empty")
|
||||
}
|
||||
|
||||
// Build tls.Certificate with PemCertificate and intermediates in the
|
||||
// PemCertificateChain
|
||||
cert := &tls.Certificate{
|
||||
PrivateKey: signer,
|
||||
}
|
||||
rest = []byte(resp.PemCertificate)
|
||||
for {
|
||||
block, rest = pem.Decode(rest)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type == "CERTIFICATE" {
|
||||
leaf, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "error parsing pemCertificate")
|
||||
}
|
||||
cert.Certificate = append(cert.Certificate, block.Bytes)
|
||||
cert.Leaf = leaf
|
||||
}
|
||||
}
|
||||
|
||||
// Add intermediates to the tls.Certificate
|
||||
last := len(bundle) - 1
|
||||
for i := 0; i < last; i++ {
|
||||
cert.Certificate = append(cert.Certificate, bundle[i].Raw)
|
||||
}
|
||||
|
||||
// Add root to the pool if it's not there yet
|
||||
rootCAs.AddCert(bundle[last])
|
||||
|
||||
return cert, &tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
}, nil
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue