Simplify SSH modifiers with options.

It also changes the behavior of the request options to modify only
the validity of the certificate.
pull/329/head
Mariano Cano 4 years ago
parent df1f7e5a2e
commit c1fc45c872

@ -24,14 +24,7 @@ const (
// certificate.
type SSHCertModifier interface {
SignOption
Modify(cert *ssh.Certificate) error
}
// SSHCertOptionModifier is the interface used to add custom options used
// to modify the SSH certificate.
type SSHCertOptionModifier interface {
SignOption
Option(o SignSSHOptions) SSHCertModifier
Modify(cert *ssh.Certificate, opts SignSSHOptions) error
}
// SSHCertValidator is the interface used to validate an SSH certificate.
@ -47,14 +40,6 @@ type SSHCertOptionsValidator interface {
Valid(got SignSSHOptions) error
}
// sshModifierFunc is an adapter to allow the use of ordinary functions as SSH
// certificate modifiers.
type sshModifierFunc func(cert *ssh.Certificate) error
func (f sshModifierFunc) Modify(cert *ssh.Certificate) error {
return f(cert)
}
// SignSSHOptions contains the options that can be passed to the SignSSH method.
type SignSSHOptions struct {
CertType string `json:"certType"`
@ -72,7 +57,7 @@ func (o SignSSHOptions) Type() uint32 {
}
// Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate.
func (o SignSSHOptions) Modify(cert *ssh.Certificate) error {
func (o SignSSHOptions) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
switch o.CertType {
case "": // ignore
case SSHUserCert:
@ -86,6 +71,12 @@ func (o SignSSHOptions) Modify(cert *ssh.Certificate) error {
cert.KeyId = o.KeyID
cert.ValidPrincipals = o.Principals
return o.ModifyValidity(cert)
}
// ModifyValidity modifies only the ValidAfter and ValidBefore on the given
// ssh.Certificate.
func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error {
t := now()
if !o.ValidAfter.IsZero() {
cert.ValidAfter = uint64(o.ValidAfter.RelativeTime(t).Unix())
@ -96,7 +87,6 @@ func (o SignSSHOptions) Modify(cert *ssh.Certificate) error {
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
return errors.New("ssh certificate valid after cannot be greater than valid before")
}
return nil
}
@ -123,7 +113,7 @@ func (o SignSSHOptions) match(got SignSSHOptions) error {
type sshCertPrincipalsModifier []string
// Modify the ValidPrincipals value of the cert.
func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error {
func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
cert.ValidPrincipals = []string(o)
return nil
}
@ -132,7 +122,7 @@ func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error {
// Key ID in the SSH certificate.
type sshCertKeyIDModifier string
func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error {
func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
cert.KeyId = string(m)
return nil
}
@ -142,7 +132,7 @@ func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error {
type sshCertTypeModifier string
// Modify sets the CertType for the ssh certificate.
func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error {
func (m sshCertTypeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
cert.CertType = sshCertTypeUInt32(string(m))
return nil
}
@ -151,7 +141,7 @@ func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error {
// ValidAfter in the SSH certificate.
type sshCertValidAfterModifier uint64
func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error {
func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
cert.ValidAfter = uint64(m)
return nil
}
@ -160,7 +150,7 @@ func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error {
// ValidBefore in the SSH certificate.
type sshCertValidBeforeModifier uint64
func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate) error {
func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
cert.ValidBefore = uint64(m)
return nil
}
@ -217,27 +207,27 @@ type sshDefaultDuration struct {
*Claimer
}
func (m *sshDefaultDuration) Option(o SignSSHOptions) SSHCertModifier {
return sshModifierFunc(func(cert *ssh.Certificate) error {
d, err := m.DefaultSSHCertDuration(cert.CertType)
if err != nil {
return err
}
// Modify implements SSHCertModifier and sets the validity if it has not been
// set, but it always applies the backdate.
func (m *sshDefaultDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error {
d, err := m.DefaultSSHCertDuration(cert.CertType)
if err != nil {
return err
}
var backdate uint64
if cert.ValidAfter == 0 {
backdate = uint64(o.Backdate / time.Second)
cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
}
if cert.ValidBefore == 0 {
cert.ValidBefore = cert.ValidAfter + uint64(d/time.Second)
}
// Apply backdate safely
if cert.ValidAfter > backdate {
cert.ValidAfter -= backdate
}
return nil
})
var backdate uint64
if cert.ValidAfter == 0 {
backdate = uint64(o.Backdate / time.Second)
cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
}
if cert.ValidBefore == 0 {
cert.ValidBefore = cert.ValidAfter + uint64(d/time.Second)
}
// Apply backdate safely
if cert.ValidAfter > backdate {
cert.ValidAfter -= backdate
}
return nil
}
// sshLimitDuration adjusts the duration to min(default, remaining provisioning
@ -250,51 +240,52 @@ type sshLimitDuration struct {
NotAfter time.Time
}
func (m *sshLimitDuration) Option(o SignSSHOptions) SSHCertModifier {
// Modify implements SSHCertModifier and modifies the validity of the
// certificate to expire before the configured limit.
func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error {
if m.NotAfter.IsZero() {
defaultDuration := &sshDefaultDuration{m.Claimer}
return defaultDuration.Option(o)
return defaultDuration.Modify(cert, o)
}
return sshModifierFunc(func(cert *ssh.Certificate) error {
d, err := m.DefaultSSHCertDuration(cert.CertType)
if err != nil {
return err
}
// Make sure the duration is within the limits.
d, err := m.DefaultSSHCertDuration(cert.CertType)
if err != nil {
return err
}
var backdate uint64
if cert.ValidAfter == 0 {
backdate = uint64(o.Backdate / time.Second)
cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
}
var backdate uint64
if cert.ValidAfter == 0 {
backdate = uint64(o.Backdate / time.Second)
cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
}
certValidAfter := time.Unix(int64(cert.ValidAfter), 0)
if certValidAfter.After(m.NotAfter) {
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validAfter (%s)",
m.NotAfter, certValidAfter)
}
certValidAfter := time.Unix(int64(cert.ValidAfter), 0)
if certValidAfter.After(m.NotAfter) {
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validAfter (%s)",
m.NotAfter, certValidAfter)
}
if cert.ValidBefore == 0 {
certValidBefore := certValidAfter.Add(d)
if m.NotAfter.Before(certValidBefore) {
certValidBefore = m.NotAfter
}
cert.ValidBefore = uint64(certValidBefore.Unix())
} else {
certValidBefore := time.Unix(int64(cert.ValidBefore), 0)
if m.NotAfter.Before(certValidBefore) {
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validBefore (%s)",
m.NotAfter, certValidBefore)
}
if cert.ValidBefore == 0 {
certValidBefore := certValidAfter.Add(d)
if m.NotAfter.Before(certValidBefore) {
certValidBefore = m.NotAfter
}
// Apply backdate safely
if cert.ValidAfter > backdate {
cert.ValidAfter -= backdate
cert.ValidBefore = uint64(certValidBefore.Unix())
} else {
certValidBefore := time.Unix(int64(cert.ValidBefore), 0)
if m.NotAfter.Before(certValidBefore) {
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validBefore (%s)",
m.NotAfter, certValidBefore)
}
}
return nil
})
// Apply backdate safely
if cert.ValidAfter > backdate {
cert.ValidAfter -= backdate
}
return nil
}
// sshCertOptionsValidator validates the user SSHOptions with the ones

@ -206,6 +206,8 @@ func (a *Authority) GetSSHBastion(ctx context.Context, user string, hostname str
// SignSSH creates a signed SSH certificate with the given public key and options.
func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var (
err error
certType sshutil.CertType
certOptions []sshutil.Option
mods []provisioner.SSHCertModifier
validators []provisioner.SSHCertValidator
@ -214,6 +216,14 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// Set backdate with the configured value
opts.Backdate = a.config.AuthorityConfig.Backdate.Duration
// Validate certificate type.
if opts.CertType != "" {
certType, err = sshutil.CertTypeFromString(opts.CertType)
if err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH")
}
}
for _, op := range signOpts {
switch o := op.(type) {
// add options to NewCertificate
@ -224,10 +234,6 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
case provisioner.SSHCertModifier:
mods = append(mods, o)
// modify the ssh.Certificate given the SSHOptions
case provisioner.SSHCertOptionModifier:
mods = append(mods, o.Option(opts))
// validate the ssh.Certificate
case provisioner.SSHCertValidator:
validators = append(validators, o)
@ -235,16 +241,24 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// validate the given SSHOptions
case provisioner.SSHCertOptionsValidator:
if err := o.Valid(opts); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH")
}
default:
return nil, errs.InternalServer("signSSH: invalid extra option type %T", o)
return nil, errs.InternalServer("authority.SignSSH: invalid extra option type %T", o)
}
}
// Simulated certificate request with request options.
cr := sshutil.CertificateRequest{
Type: certType,
KeyID: opts.KeyID,
Principals: opts.Principals,
Key: key,
}
// Create certificate from template.
certificate, err := sshutil.NewCertificate(key, certOptions...)
certificate, err := sshutil.NewCertificate(cr, certOptions...)
if err != nil {
if _, ok := err.(*sshutil.TemplateError); ok {
return nil, errs.NewErr(http.StatusBadRequest, err,
@ -255,19 +269,19 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH")
}
// Get actual *ssh.Certificate and continue with user and provisioner
// modifiers.
// Get actual *ssh.Certificate and continue with provisioner modifiers.
cert := certificate.GetCertificate()
// Use SignSSHOptions to modify the certificate.
if err := opts.Modify(cert); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
// Use SignSSHOptions to modify the certificate validity. It will be later
// checked or set if not defined.
if err := opts.ModifyValidity(cert); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH")
}
// Use provisioner modifiers.
for _, m := range mods {
if err := m.Modify(cert); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
if err := m.Modify(cert, opts); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH")
}
}
@ -276,33 +290,33 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
switch cert.CertType {
case ssh.UserCert:
if a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented("signSSH: user certificate signing is not enabled")
return nil, errs.NotImplemented("authority.SignSSH: user certificate signing is not enabled")
}
signer = a.sshCAUserCertSignKey
case ssh.HostCert:
if a.sshCAHostCertSignKey == nil {
return nil, errs.NotImplemented("signSSH: host certificate signing is not enabled")
return nil, errs.NotImplemented("authority.SignSSH: host certificate signing is not enabled")
}
signer = a.sshCAHostCertSignKey
default:
return nil, errs.InternalServer("signSSH: unexpected ssh certificate type: %d", cert.CertType)
return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", cert.CertType)
}
// Sign certificate.
cert, err = sshutil.CreateCertificate(cert, signer)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate")
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error signing certificate")
}
// User provisioners validators.
for _, v := range validators {
if err := v.Valid(cert, opts); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH")
}
}
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error storing certificate in db")
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error storing certificate in db")
}
return cert, nil

Loading…
Cancel
Save